mirror of
https://github.com/csunny/DB-GPT.git
synced 2025-08-03 09:34:04 +00:00
Merge remote-tracking branch 'origin/main' into feat_rag_graph
This commit is contained in:
commit
f6694d95ec
@ -64,6 +64,20 @@ class WorkerApplyOutput:
|
||||
# The seconds cost to apply some action to worker instances
|
||||
timecost: Optional[int] = -1
|
||||
|
||||
@staticmethod
|
||||
def reduce(outs: List["WorkerApplyOutput"]) -> "WorkerApplyOutput":
|
||||
"""Merge all outputs
|
||||
|
||||
Args:
|
||||
outs (List["WorkerApplyOutput"]): The list of WorkerApplyOutput
|
||||
"""
|
||||
if not outs:
|
||||
return WorkerApplyOutput("Not outputs")
|
||||
combined_success = all(out.success for out in outs)
|
||||
max_timecost = max(out.timecost for out in outs)
|
||||
combined_message = ", ".join(out.message for out in outs)
|
||||
return WorkerApplyOutput(combined_message, combined_success, max_timecost)
|
||||
|
||||
|
||||
@dataclass
|
||||
class SupportedModel:
|
||||
|
@ -185,7 +185,7 @@ def run_model_controller():
|
||||
setup_logging(
|
||||
"pilot",
|
||||
logging_level=controller_params.log_level,
|
||||
logger_filename="dbgpt_model_controller.log",
|
||||
logger_filename=controller_params.log_file,
|
||||
)
|
||||
|
||||
initialize_controller(host=controller_params.host, port=controller_params.port)
|
||||
|
@ -26,11 +26,22 @@ class WorkerRunData:
|
||||
_heartbeat_future: Optional[Future] = None
|
||||
_last_heartbeat: Optional[datetime] = None
|
||||
|
||||
def _to_print_key(self):
|
||||
model_name = self.model_params.model_name
|
||||
model_type = self.model_params.model_type
|
||||
host = self.host
|
||||
port = self.port
|
||||
return f"model {model_name}@{model_type}({host}:{port})"
|
||||
|
||||
|
||||
class WorkerManager(ABC):
|
||||
@abstractmethod
|
||||
async def start(self):
|
||||
"""Start worker manager"""
|
||||
"""Start worker manager
|
||||
|
||||
Raises:
|
||||
Exception: if start worker manager not successfully
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
async def stop(self, ignore_exception: bool = False):
|
||||
@ -69,11 +80,11 @@ class WorkerManager(ABC):
|
||||
"""List supported models"""
|
||||
|
||||
@abstractmethod
|
||||
async def model_startup(self, startup_req: WorkerStartupRequest) -> bool:
|
||||
async def model_startup(self, startup_req: WorkerStartupRequest):
|
||||
"""Create and start a model instance"""
|
||||
|
||||
@abstractmethod
|
||||
async def model_shutdown(self, shutdown_req: WorkerStartupRequest) -> bool:
|
||||
async def model_shutdown(self, shutdown_req: WorkerStartupRequest):
|
||||
"""Shutdown model instance"""
|
||||
|
||||
@abstractmethod
|
||||
|
@ -104,12 +104,16 @@ class LocalWorkerManager(WorkerManager):
|
||||
return f"{model_name}@{worker_type}"
|
||||
|
||||
async def run_blocking_func(self, func, *args):
|
||||
if asyncio.iscoroutinefunction(func):
|
||||
raise ValueError(f"The function {func} is not blocking function")
|
||||
loop = asyncio.get_event_loop()
|
||||
return await loop.run_in_executor(self.executor, func, *args)
|
||||
|
||||
async def start(self):
|
||||
if len(self.workers) > 0:
|
||||
await self._start_all_worker(apply_req=None)
|
||||
out = await self._start_all_worker(apply_req=None)
|
||||
if not out.success:
|
||||
raise Exception(out.message)
|
||||
if self.register_func:
|
||||
await self.register_func(self.run_data)
|
||||
if self.send_heartbeat_func:
|
||||
@ -143,7 +147,9 @@ class LocalWorkerManager(WorkerManager):
|
||||
else:
|
||||
stop_tasks.append(self.deregister_func(self.run_data))
|
||||
|
||||
await asyncio.gather(*stop_tasks)
|
||||
results = await asyncio.gather(*stop_tasks)
|
||||
if not results[0].success and not ignore_exception:
|
||||
raise Exception(results[0].message)
|
||||
|
||||
def after_start(self, listener: Callable[["WorkerManager"], None]):
|
||||
self.start_listeners.append(listener)
|
||||
@ -193,7 +199,15 @@ class LocalWorkerManager(WorkerManager):
|
||||
logger.warn(f"Instance {worker_key} exist")
|
||||
return False
|
||||
|
||||
async def model_startup(self, startup_req: WorkerStartupRequest) -> bool:
|
||||
def _remove_worker(self, worker_params: ModelWorkerParameters) -> None:
|
||||
worker_key = self._worker_key(
|
||||
worker_params.worker_type, worker_params.model_name
|
||||
)
|
||||
instances = self.workers.get(worker_key)
|
||||
if instances:
|
||||
del self.workers[worker_key]
|
||||
|
||||
async def model_startup(self, startup_req: WorkerStartupRequest):
|
||||
"""Start model"""
|
||||
model_name = startup_req.model
|
||||
worker_type = startup_req.worker_type
|
||||
@ -213,22 +227,30 @@ class LocalWorkerManager(WorkerManager):
|
||||
self.add_worker, worker, worker_params, command_args
|
||||
)
|
||||
if not success:
|
||||
logger.warn(
|
||||
f"Add worker failed, worker instances is exist, worker_params: {worker_params}"
|
||||
)
|
||||
return False
|
||||
msg = f"Add worker {model_name}@{worker_type}, worker instances is exist"
|
||||
logger.warn(f"{msg}, worker_params: {worker_params}")
|
||||
self._remove_worker(worker_params)
|
||||
raise Exception(msg)
|
||||
supported_types = WorkerType.values()
|
||||
if worker_type not in supported_types:
|
||||
self._remove_worker(worker_params)
|
||||
raise ValueError(
|
||||
f"Unsupported worker type: {worker_type}, now supported worker type: {supported_types}"
|
||||
)
|
||||
start_apply_req = WorkerApplyRequest(
|
||||
model=model_name, apply_type=WorkerApplyType.START, worker_type=worker_type
|
||||
)
|
||||
await self.worker_apply(start_apply_req)
|
||||
return True
|
||||
out: WorkerApplyOutput = None
|
||||
try:
|
||||
out = await self.worker_apply(start_apply_req)
|
||||
except Exception as e:
|
||||
self._remove_worker(worker_params)
|
||||
raise e
|
||||
if not out.success:
|
||||
self._remove_worker(worker_params)
|
||||
raise Exception(out.message)
|
||||
|
||||
async def model_shutdown(self, shutdown_req: WorkerStartupRequest) -> bool:
|
||||
async def model_shutdown(self, shutdown_req: WorkerStartupRequest):
|
||||
logger.info(f"Begin shutdown model, shutdown_req: {shutdown_req}")
|
||||
apply_req = WorkerApplyRequest(
|
||||
model=shutdown_req.model,
|
||||
@ -236,8 +258,7 @@ class LocalWorkerManager(WorkerManager):
|
||||
worker_type=shutdown_req.worker_type,
|
||||
)
|
||||
out = await self._stop_all_worker(apply_req)
|
||||
if out.success:
|
||||
return True
|
||||
if not out.success:
|
||||
raise Exception(out.message)
|
||||
|
||||
async def supported_models(self) -> List[WorkerSupportedModel]:
|
||||
@ -253,7 +274,7 @@ class LocalWorkerManager(WorkerManager):
|
||||
self, worker_type: str, model_name: str, healthy_only: bool = True
|
||||
) -> List[WorkerRunData]:
|
||||
worker_key = self._worker_key(worker_type, model_name)
|
||||
return self.workers.get(worker_key)
|
||||
return self.workers.get(worker_key, [])
|
||||
|
||||
def _simple_select(
|
||||
self, worker_type: str, model_name: str, worker_instances: List[WorkerRunData]
|
||||
@ -424,10 +445,15 @@ class LocalWorkerManager(WorkerManager):
|
||||
async def _start_all_worker(
|
||||
self, apply_req: WorkerApplyRequest
|
||||
) -> WorkerApplyOutput:
|
||||
# TODO avoid start twice
|
||||
start_time = time.time()
|
||||
logger.info(f"Begin start all worker, apply_req: {apply_req}")
|
||||
|
||||
async def _start_worker(worker_run_data: WorkerRunData):
|
||||
_start_time = time.time()
|
||||
info = worker_run_data._to_print_key()
|
||||
out = WorkerApplyOutput("")
|
||||
try:
|
||||
await self.run_blocking_func(
|
||||
worker_run_data.worker.start,
|
||||
worker_run_data.model_params,
|
||||
@ -448,12 +474,18 @@ class LocalWorkerManager(WorkerManager):
|
||||
self.send_heartbeat_func,
|
||||
)
|
||||
)
|
||||
out.message = f"{info} start successfully"
|
||||
except Exception as e:
|
||||
out.success = False
|
||||
out.message = f"{info} start failed, {str(e)}"
|
||||
finally:
|
||||
out.timecost = time.time() - _start_time
|
||||
return out
|
||||
|
||||
await self._apply_worker(apply_req, _start_worker)
|
||||
timecost = time.time() - start_time
|
||||
return WorkerApplyOutput(
|
||||
message=f"Worker started successfully", timecost=timecost
|
||||
)
|
||||
outs = await self._apply_worker(apply_req, _start_worker)
|
||||
out = WorkerApplyOutput.reduce(outs)
|
||||
out.timecost = time.time() - start_time
|
||||
return out
|
||||
|
||||
async def _stop_all_worker(
|
||||
self, apply_req: WorkerApplyRequest, ignore_exception: bool = False
|
||||
@ -461,6 +493,10 @@ class LocalWorkerManager(WorkerManager):
|
||||
start_time = time.time()
|
||||
|
||||
async def _stop_worker(worker_run_data: WorkerRunData):
|
||||
_start_time = time.time()
|
||||
info = worker_run_data._to_print_key()
|
||||
out = WorkerApplyOutput("")
|
||||
try:
|
||||
await self.run_blocking_func(worker_run_data.worker.stop)
|
||||
# Set stop event
|
||||
worker_run_data.stop_event.set()
|
||||
@ -486,17 +522,27 @@ class LocalWorkerManager(WorkerManager):
|
||||
|
||||
_deregister_func = safe_deregister_func
|
||||
await _deregister_func(worker_run_data)
|
||||
# Remove metadata
|
||||
self._remove_worker(worker_run_data.worker_params)
|
||||
out.message = f"{info} stop successfully"
|
||||
except Exception as e:
|
||||
out.success = False
|
||||
out.message = f"{info} stop failed, {str(e)}"
|
||||
finally:
|
||||
out.timecost = time.time() - _start_time
|
||||
return out
|
||||
|
||||
await self._apply_worker(apply_req, _stop_worker)
|
||||
timecost = time.time() - start_time
|
||||
return WorkerApplyOutput(
|
||||
message=f"Worker stopped successfully", timecost=timecost
|
||||
)
|
||||
outs = await self._apply_worker(apply_req, _stop_worker)
|
||||
out = WorkerApplyOutput.reduce(outs)
|
||||
out.timecost = time.time() - start_time
|
||||
return out
|
||||
|
||||
async def _restart_all_worker(
|
||||
self, apply_req: WorkerApplyRequest
|
||||
) -> WorkerApplyOutput:
|
||||
await self._stop_all_worker(apply_req)
|
||||
out = await self._stop_all_worker(apply_req, ignore_exception=True)
|
||||
if not out.success:
|
||||
return out
|
||||
return await self._start_all_worker(apply_req)
|
||||
|
||||
async def _update_all_worker_params(
|
||||
@ -541,10 +587,10 @@ class WorkerManagerAdapter(WorkerManager):
|
||||
async def supported_models(self) -> List[WorkerSupportedModel]:
|
||||
return await self.worker_manager.supported_models()
|
||||
|
||||
async def model_startup(self, startup_req: WorkerStartupRequest) -> bool:
|
||||
async def model_startup(self, startup_req: WorkerStartupRequest):
|
||||
return await self.worker_manager.model_startup(startup_req)
|
||||
|
||||
async def model_shutdown(self, shutdown_req: WorkerStartupRequest) -> bool:
|
||||
async def model_shutdown(self, shutdown_req: WorkerStartupRequest):
|
||||
return await self.worker_manager.model_shutdown(shutdown_req)
|
||||
|
||||
async def get_model_instances(
|
||||
@ -960,7 +1006,7 @@ def run_worker_manager(
|
||||
setup_logging(
|
||||
"pilot",
|
||||
logging_level=worker_params.log_level,
|
||||
logger_filename="dbgpt_model_worker_manager.log",
|
||||
logger_filename=worker_params.log_file,
|
||||
)
|
||||
|
||||
embedded_mod = True
|
||||
@ -973,7 +1019,7 @@ def run_worker_manager(
|
||||
system_app = SystemApp(app)
|
||||
initialize_tracer(
|
||||
system_app,
|
||||
os.path.join(LOGDIR, "dbgpt_model_worker_manager_tracer.jsonl"),
|
||||
os.path.join(LOGDIR, worker_params.tracer_file),
|
||||
root_operation_name="DB-GPT-WorkerManager-Entry",
|
||||
)
|
||||
|
||||
|
@ -96,7 +96,7 @@ class RemoteWorkerManager(LocalWorkerManager):
|
||||
raise Exception(error_msg)
|
||||
return worker_instances
|
||||
|
||||
async def model_startup(self, startup_req: WorkerStartupRequest) -> bool:
|
||||
async def model_startup(self, startup_req: WorkerStartupRequest):
|
||||
worker_instances = await self._get_worker_service_instance(
|
||||
startup_req.host, startup_req.port
|
||||
)
|
||||
@ -107,10 +107,10 @@ class RemoteWorkerManager(LocalWorkerManager):
|
||||
"/models/startup",
|
||||
method="POST",
|
||||
json=startup_req.dict(),
|
||||
success_handler=lambda x: True,
|
||||
success_handler=lambda x: None,
|
||||
)
|
||||
|
||||
async def model_shutdown(self, shutdown_req: WorkerStartupRequest) -> bool:
|
||||
async def model_shutdown(self, shutdown_req: WorkerStartupRequest):
|
||||
worker_instances = await self._get_worker_service_instance(
|
||||
shutdown_req.host, shutdown_req.port
|
||||
)
|
||||
@ -121,7 +121,7 @@ class RemoteWorkerManager(LocalWorkerManager):
|
||||
"/models/shutdown",
|
||||
method="POST",
|
||||
json=shutdown_req.dict(),
|
||||
success_handler=lambda x: True,
|
||||
success_handler=lambda x: None,
|
||||
)
|
||||
|
||||
def _build_worker_instances(
|
||||
|
0
pilot/model/cluster/worker/tests/__init__.py
Normal file
0
pilot/model/cluster/worker/tests/__init__.py
Normal file
168
pilot/model/cluster/worker/tests/base_tests.py
Normal file
168
pilot/model/cluster/worker/tests/base_tests.py
Normal file
@ -0,0 +1,168 @@
|
||||
import pytest
|
||||
import pytest_asyncio
|
||||
from contextlib import contextmanager, asynccontextmanager
|
||||
from typing import List, Iterator, Dict, Tuple
|
||||
from pilot.model.parameter import ModelParameters, ModelWorkerParameters, WorkerType
|
||||
from pilot.model.base import ModelOutput
|
||||
from pilot.model.cluster.worker_base import ModelWorker
|
||||
from pilot.model.cluster.worker.manager import (
|
||||
LocalWorkerManager,
|
||||
RegisterFunc,
|
||||
DeregisterFunc,
|
||||
SendHeartbeatFunc,
|
||||
ApplyFunction,
|
||||
)
|
||||
|
||||
|
||||
class MockModelWorker(ModelWorker):
|
||||
def __init__(
|
||||
self,
|
||||
model_parameters: ModelParameters,
|
||||
error_worker: bool = False,
|
||||
stop_error: bool = False,
|
||||
stream_messags: List[str] = None,
|
||||
embeddings: List[List[float]] = None,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
if not stream_messags:
|
||||
stream_messags = []
|
||||
if not embeddings:
|
||||
embeddings = []
|
||||
self.model_parameters = model_parameters
|
||||
self.error_worker = error_worker
|
||||
self.stop_error = stop_error
|
||||
self.stream_messags = stream_messags
|
||||
self._embeddings = embeddings
|
||||
|
||||
def parse_parameters(self, command_args: List[str] = None) -> ModelParameters:
|
||||
return self.model_parameters
|
||||
|
||||
def load_worker(self, model_name: str, model_path: str, **kwargs) -> None:
|
||||
pass
|
||||
|
||||
def start(
|
||||
self, model_params: ModelParameters = None, command_args: List[str] = None
|
||||
) -> None:
|
||||
if self.error_worker:
|
||||
raise Exception("Start worker error for mock")
|
||||
|
||||
def stop(self) -> None:
|
||||
if self.stop_error:
|
||||
raise Exception("Stop worker error for mock")
|
||||
|
||||
def generate_stream(self, params: Dict) -> Iterator[ModelOutput]:
|
||||
for msg in self.stream_messags:
|
||||
yield ModelOutput(text=msg, error_code=0)
|
||||
|
||||
def generate(self, params: Dict) -> ModelOutput:
|
||||
output = None
|
||||
for out in self.generate_stream(params):
|
||||
output = out
|
||||
return output
|
||||
|
||||
def embeddings(self, params: Dict) -> List[List[float]]:
|
||||
return self._embeddings
|
||||
|
||||
|
||||
_TEST_MODEL_NAME = "vicuna-13b-v1.5"
|
||||
_TEST_MODEL_PATH = "/app/models/vicuna-13b-v1.5"
|
||||
|
||||
|
||||
def _new_worker_params(
|
||||
model_name: str = _TEST_MODEL_NAME,
|
||||
model_path: str = _TEST_MODEL_PATH,
|
||||
worker_type: str = WorkerType.LLM.value,
|
||||
) -> ModelWorkerParameters:
|
||||
return ModelWorkerParameters(
|
||||
model_name=model_name, model_path=model_path, worker_type=worker_type
|
||||
)
|
||||
|
||||
|
||||
def _create_workers(
|
||||
num_workers: int,
|
||||
error_worker: bool = False,
|
||||
stop_error: bool = False,
|
||||
worker_type: str = WorkerType.LLM.value,
|
||||
stream_messags: List[str] = None,
|
||||
embeddings: List[List[float]] = None,
|
||||
) -> List[Tuple[ModelWorker, ModelWorkerParameters]]:
|
||||
workers = []
|
||||
for i in range(num_workers):
|
||||
model_name = f"test-model-name-{i}"
|
||||
model_path = f"test-model-path-{i}"
|
||||
model_parameters = ModelParameters(model_name=model_name, model_path=model_path)
|
||||
worker = MockModelWorker(
|
||||
model_parameters,
|
||||
error_worker=error_worker,
|
||||
stop_error=stop_error,
|
||||
stream_messags=stream_messags,
|
||||
embeddings=embeddings,
|
||||
)
|
||||
worker_params = _new_worker_params(
|
||||
model_name, model_path, worker_type=worker_type
|
||||
)
|
||||
workers.append((worker, worker_params))
|
||||
return workers
|
||||
|
||||
|
||||
@asynccontextmanager
|
||||
async def _start_worker_manager(**kwargs):
|
||||
register_func = kwargs.get("register_func")
|
||||
deregister_func = kwargs.get("deregister_func")
|
||||
send_heartbeat_func = kwargs.get("send_heartbeat_func")
|
||||
model_registry = kwargs.get("model_registry")
|
||||
workers = kwargs.get("workers")
|
||||
num_workers = int(kwargs.get("num_workers", 0))
|
||||
start = kwargs.get("start", True)
|
||||
stop = kwargs.get("stop", True)
|
||||
error_worker = kwargs.get("error_worker", False)
|
||||
stop_error = kwargs.get("stop_error", False)
|
||||
stream_messags = kwargs.get("stream_messags", [])
|
||||
embeddings = kwargs.get("embeddings", [])
|
||||
|
||||
worker_manager = LocalWorkerManager(
|
||||
register_func=register_func,
|
||||
deregister_func=deregister_func,
|
||||
send_heartbeat_func=send_heartbeat_func,
|
||||
model_registry=model_registry,
|
||||
)
|
||||
|
||||
for worker, worker_params in _create_workers(
|
||||
num_workers, error_worker, stop_error, stream_messags, embeddings
|
||||
):
|
||||
worker_manager.add_worker(worker, worker_params)
|
||||
if workers:
|
||||
for worker, worker_params in workers:
|
||||
worker_manager.add_worker(worker, worker_params)
|
||||
|
||||
if start:
|
||||
await worker_manager.start()
|
||||
|
||||
yield worker_manager
|
||||
if stop:
|
||||
await worker_manager.stop()
|
||||
|
||||
|
||||
@pytest_asyncio.fixture
|
||||
async def manager_2_workers(request):
|
||||
param = getattr(request, "param", {})
|
||||
async with _start_worker_manager(num_workers=2, **param) as worker_manager:
|
||||
yield worker_manager
|
||||
|
||||
|
||||
@pytest_asyncio.fixture
|
||||
async def manager_with_2_workers(request):
|
||||
param = getattr(request, "param", {})
|
||||
workers = _create_workers(2, stream_messags=param.get("stream_messags", []))
|
||||
async with _start_worker_manager(workers=workers, **param) as worker_manager:
|
||||
yield (worker_manager, workers)
|
||||
|
||||
|
||||
@pytest_asyncio.fixture
|
||||
async def manager_2_embedding_workers(request):
|
||||
param = getattr(request, "param", {})
|
||||
workers = _create_workers(
|
||||
2, worker_type=WorkerType.TEXT2VEC.value, embeddings=param.get("embeddings", [])
|
||||
)
|
||||
async with _start_worker_manager(workers=workers, **param) as worker_manager:
|
||||
yield (worker_manager, workers)
|
486
pilot/model/cluster/worker/tests/test_manager.py
Normal file
486
pilot/model/cluster/worker/tests/test_manager.py
Normal file
@ -0,0 +1,486 @@
|
||||
from unittest.mock import patch, AsyncMock
|
||||
import pytest
|
||||
from typing import List, Iterator, Dict, Tuple
|
||||
from dataclasses import asdict
|
||||
from pilot.model.parameter import ModelParameters, ModelWorkerParameters, WorkerType
|
||||
from pilot.model.base import ModelOutput, WorkerApplyType
|
||||
from pilot.model.cluster.base import WorkerApplyRequest, WorkerStartupRequest
|
||||
from pilot.model.cluster.worker_base import ModelWorker
|
||||
from pilot.model.cluster.manager_base import WorkerRunData
|
||||
from pilot.model.cluster.worker.manager import (
|
||||
LocalWorkerManager,
|
||||
RegisterFunc,
|
||||
DeregisterFunc,
|
||||
SendHeartbeatFunc,
|
||||
ApplyFunction,
|
||||
)
|
||||
from pilot.model.cluster.worker.tests.base_tests import (
|
||||
MockModelWorker,
|
||||
manager_2_workers,
|
||||
manager_with_2_workers,
|
||||
manager_2_embedding_workers,
|
||||
_create_workers,
|
||||
_start_worker_manager,
|
||||
_new_worker_params,
|
||||
)
|
||||
|
||||
|
||||
_TEST_MODEL_NAME = "vicuna-13b-v1.5"
|
||||
_TEST_MODEL_PATH = "/app/models/vicuna-13b-v1.5"
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def worker():
|
||||
mock_worker = _create_workers(1)
|
||||
yield mock_worker[0][0]
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def worker_param():
|
||||
return _new_worker_params()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def manager(request):
|
||||
if not request or not hasattr(request, "param"):
|
||||
register_func = None
|
||||
deregister_func = None
|
||||
send_heartbeat_func = None
|
||||
model_registry = None
|
||||
workers = []
|
||||
else:
|
||||
register_func = request.param.get("register_func")
|
||||
deregister_func = request.param.get("deregister_func")
|
||||
send_heartbeat_func = request.param.get("send_heartbeat_func")
|
||||
model_registry = request.param.get("model_registry")
|
||||
workers = request.param.get("model_registry")
|
||||
|
||||
worker_manager = LocalWorkerManager(
|
||||
register_func=register_func,
|
||||
deregister_func=deregister_func,
|
||||
send_heartbeat_func=send_heartbeat_func,
|
||||
model_registry=model_registry,
|
||||
)
|
||||
yield worker_manager
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_run_blocking_func(manager: LocalWorkerManager):
|
||||
def f1() -> int:
|
||||
return 0
|
||||
|
||||
def f2(a: int, b: int) -> int:
|
||||
return a + b
|
||||
|
||||
async def error_f3() -> None:
|
||||
return 0
|
||||
|
||||
assert await manager.run_blocking_func(f1) == 0
|
||||
assert await manager.run_blocking_func(f2, 1, 2) == 3
|
||||
with pytest.raises(ValueError):
|
||||
await manager.run_blocking_func(error_f3)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_add_worker(
|
||||
manager: LocalWorkerManager,
|
||||
worker: ModelWorker,
|
||||
worker_param: ModelWorkerParameters,
|
||||
):
|
||||
# TODO test with register function
|
||||
assert manager.add_worker(worker, worker_param)
|
||||
# Add again
|
||||
assert manager.add_worker(worker, worker_param) == False
|
||||
key = manager._worker_key(worker_param.worker_type, worker_param.model_name)
|
||||
assert len(manager.workers) == 1
|
||||
assert len(manager.workers[key]) == 1
|
||||
assert manager.workers[key][0].worker == worker
|
||||
|
||||
assert manager.add_worker(
|
||||
worker,
|
||||
_new_worker_params(
|
||||
model_name="chatglm2-6b", model_path="/app/models/chatglm2-6b"
|
||||
),
|
||||
)
|
||||
assert (
|
||||
manager.add_worker(
|
||||
worker,
|
||||
_new_worker_params(
|
||||
model_name="chatglm2-6b", model_path="/app/models/chatglm2-6b"
|
||||
),
|
||||
)
|
||||
== False
|
||||
)
|
||||
assert len(manager.workers) == 2
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test__apply_worker(manager_2_workers: LocalWorkerManager):
|
||||
manager = manager_2_workers
|
||||
|
||||
async def f1(wr: WorkerRunData) -> int:
|
||||
return 0
|
||||
|
||||
# Apply to all workers
|
||||
assert await manager._apply_worker(None, apply_func=f1) == [0, 0]
|
||||
|
||||
workers = _create_workers(4)
|
||||
async with _start_worker_manager(workers=workers) as manager:
|
||||
# Apply to single model
|
||||
req = WorkerApplyRequest(
|
||||
model=workers[0][1].model_name,
|
||||
apply_type=WorkerApplyType.START,
|
||||
worker_type=WorkerType.LLM,
|
||||
)
|
||||
assert await manager._apply_worker(req, apply_func=f1) == [0]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.parametrize("manager_2_workers", [{"start": False}], indirect=True)
|
||||
async def test__start_all_worker(manager_2_workers: LocalWorkerManager):
|
||||
manager = manager_2_workers
|
||||
out = await manager._start_all_worker(None)
|
||||
assert out.success
|
||||
assert len(manager.workers) == 2
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.parametrize(
|
||||
"manager_2_workers, is_error_worker",
|
||||
[
|
||||
({"start": False, "error_worker": False}, False),
|
||||
({"start": False, "error_worker": True}, True),
|
||||
],
|
||||
indirect=["manager_2_workers"],
|
||||
)
|
||||
async def test_start_worker_manager(
|
||||
manager_2_workers: LocalWorkerManager, is_error_worker: bool
|
||||
):
|
||||
manager = manager_2_workers
|
||||
if is_error_worker:
|
||||
with pytest.raises(Exception):
|
||||
await manager.start()
|
||||
else:
|
||||
await manager.start()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.parametrize(
|
||||
"manager_2_workers, is_stop_error",
|
||||
[
|
||||
({"stop": False, "stop_error": False}, False),
|
||||
({"stop": False, "stop_error": True}, True),
|
||||
],
|
||||
indirect=["manager_2_workers"],
|
||||
)
|
||||
async def test__stop_all_worker(
|
||||
manager_2_workers: LocalWorkerManager, is_stop_error: bool
|
||||
):
|
||||
manager = manager_2_workers
|
||||
out = await manager._stop_all_worker(None)
|
||||
if is_stop_error:
|
||||
assert not out.success
|
||||
else:
|
||||
assert out.success
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test__restart_all_worker(manager_2_workers: LocalWorkerManager):
|
||||
manager = manager_2_workers
|
||||
out = await manager._restart_all_worker(None)
|
||||
assert out.success
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.parametrize(
|
||||
"manager_2_workers, is_stop_error",
|
||||
[
|
||||
({"stop": False, "stop_error": False}, False),
|
||||
({"stop": False, "stop_error": True}, True),
|
||||
],
|
||||
indirect=["manager_2_workers"],
|
||||
)
|
||||
async def test_stop_worker_manager(
|
||||
manager_2_workers: LocalWorkerManager, is_stop_error: bool
|
||||
):
|
||||
manager = manager_2_workers
|
||||
if is_stop_error:
|
||||
with pytest.raises(Exception):
|
||||
await manager.stop()
|
||||
else:
|
||||
await manager.stop()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test__remove_worker():
|
||||
workers = _create_workers(3)
|
||||
async with _start_worker_manager(workers=workers, stop=False) as manager:
|
||||
assert len(manager.workers) == 3
|
||||
for _, worker_params in workers:
|
||||
manager._remove_worker(worker_params)
|
||||
not_exist_parmas = _new_worker_params(
|
||||
model_name="this is a not exist worker params"
|
||||
)
|
||||
manager._remove_worker(not_exist_parmas)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch("pilot.model.cluster.worker.manager._build_worker")
|
||||
async def test_model_startup(mock_build_worker):
|
||||
async with _start_worker_manager() as manager:
|
||||
workers = _create_workers(1)
|
||||
worker, worker_params = workers[0]
|
||||
mock_build_worker.return_value = worker
|
||||
|
||||
req = WorkerStartupRequest(
|
||||
host="127.0.0.1",
|
||||
port=8001,
|
||||
model=worker_params.model_name,
|
||||
worker_type=WorkerType.LLM,
|
||||
params=asdict(worker_params),
|
||||
)
|
||||
await manager.model_startup(req)
|
||||
with pytest.raises(Exception):
|
||||
await manager.model_startup(req)
|
||||
|
||||
async with _start_worker_manager() as manager:
|
||||
workers = _create_workers(1, error_worker=True)
|
||||
worker, worker_params = workers[0]
|
||||
mock_build_worker.return_value = worker
|
||||
req = WorkerStartupRequest(
|
||||
host="127.0.0.1",
|
||||
port=8001,
|
||||
model=worker_params.model_name,
|
||||
worker_type=WorkerType.LLM,
|
||||
params=asdict(worker_params),
|
||||
)
|
||||
with pytest.raises(Exception):
|
||||
await manager.model_startup(req)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch("pilot.model.cluster.worker.manager._build_worker")
|
||||
async def test_model_shutdown(mock_build_worker):
|
||||
async with _start_worker_manager(start=False, stop=False) as manager:
|
||||
workers = _create_workers(1)
|
||||
worker, worker_params = workers[0]
|
||||
mock_build_worker.return_value = worker
|
||||
|
||||
req = WorkerStartupRequest(
|
||||
host="127.0.0.1",
|
||||
port=8001,
|
||||
model=worker_params.model_name,
|
||||
worker_type=WorkerType.LLM,
|
||||
params=asdict(worker_params),
|
||||
)
|
||||
await manager.model_startup(req)
|
||||
await manager.model_shutdown(req)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_supported_models(manager_2_workers: LocalWorkerManager):
|
||||
manager = manager_2_workers
|
||||
models = await manager.supported_models()
|
||||
assert len(models) == 1
|
||||
models = models[0].models
|
||||
assert len(models) > 10
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.parametrize(
|
||||
"is_async",
|
||||
[
|
||||
True,
|
||||
False,
|
||||
],
|
||||
)
|
||||
async def test_get_model_instances(is_async):
|
||||
workers = _create_workers(3)
|
||||
async with _start_worker_manager(workers=workers, stop=False) as manager:
|
||||
assert len(manager.workers) == 3
|
||||
for _, worker_params in workers:
|
||||
model_name = worker_params.model_name
|
||||
worker_type = worker_params.worker_type
|
||||
if is_async:
|
||||
assert (
|
||||
len(await manager.get_model_instances(worker_type, model_name)) == 1
|
||||
)
|
||||
else:
|
||||
assert (
|
||||
len(manager.sync_get_model_instances(worker_type, model_name)) == 1
|
||||
)
|
||||
if is_async:
|
||||
assert not await manager.get_model_instances(
|
||||
worker_type, "this is not exist model instances"
|
||||
)
|
||||
else:
|
||||
assert not manager.sync_get_model_instances(
|
||||
worker_type, "this is not exist model instances"
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test__simple_select(
|
||||
manager_with_2_workers: Tuple[
|
||||
LocalWorkerManager, List[Tuple[ModelWorker, ModelWorkerParameters]]
|
||||
]
|
||||
):
|
||||
manager, workers = manager_with_2_workers
|
||||
for _, worker_params in workers:
|
||||
model_name = worker_params.model_name
|
||||
worker_type = worker_params.worker_type
|
||||
instances = await manager.get_model_instances(worker_type, model_name)
|
||||
assert instances
|
||||
inst = manager._simple_select(worker_params.worker_type, model_name, instances)
|
||||
assert inst is not None
|
||||
assert inst.worker_params == worker_params
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.parametrize(
|
||||
"is_async",
|
||||
[
|
||||
True,
|
||||
False,
|
||||
],
|
||||
)
|
||||
async def test_select_one_instance(
|
||||
is_async: bool,
|
||||
manager_with_2_workers: Tuple[
|
||||
LocalWorkerManager, List[Tuple[ModelWorker, ModelWorkerParameters]]
|
||||
],
|
||||
):
|
||||
manager, workers = manager_with_2_workers
|
||||
for _, worker_params in workers:
|
||||
model_name = worker_params.model_name
|
||||
worker_type = worker_params.worker_type
|
||||
if is_async:
|
||||
inst = await manager.select_one_instance(worker_type, model_name)
|
||||
else:
|
||||
inst = manager.sync_select_one_instance(worker_type, model_name)
|
||||
assert inst is not None
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.parametrize(
|
||||
"is_async",
|
||||
[
|
||||
True,
|
||||
False,
|
||||
],
|
||||
)
|
||||
async def test__get_model(
|
||||
is_async: bool,
|
||||
manager_with_2_workers: Tuple[
|
||||
LocalWorkerManager, List[Tuple[ModelWorker, ModelWorkerParameters]]
|
||||
],
|
||||
):
|
||||
manager, workers = manager_with_2_workers
|
||||
for _, worker_params in workers:
|
||||
model_name = worker_params.model_name
|
||||
worker_type = worker_params.worker_type
|
||||
params = {"model": model_name}
|
||||
if is_async:
|
||||
wr = await manager._get_model(params, worker_type=worker_type)
|
||||
else:
|
||||
wr = manager._sync_get_model(params, worker_type=worker_type)
|
||||
assert wr is not None
|
||||
|
||||
|
||||
@pytest.mark.asyncio
@pytest.mark.parametrize(
    "manager_with_2_workers, expected_messages",
    [
        ({"stream_messags": ["Hello", " world."]}, "Hello world."),
        ({"stream_messags": ["你好,我是", "张三。"]}, "你好,我是张三。"),
    ],
    indirect=["manager_with_2_workers"],
)
async def test_generate_stream(
    manager_with_2_workers: Tuple[
        LocalWorkerManager, List[Tuple[ModelWorker, ModelWorkerParameters]]
    ],
    expected_messages: str,
):
    """Streamed output chunks concatenate to the full expected message.

    NOTE(review): the fixture key "stream_messags" looks like a typo for
    "stream_messages", but it is the contract with the indirect fixture —
    renaming it here alone would break the test; confirm against conftest.
    """
    manager, workers = manager_with_2_workers
    for _worker, params in workers:
        chunks = []
        async for output in manager.generate_stream({"model": params.model_name}):
            chunks.append(output.text)
        assert "".join(chunks) == expected_messages
|
||||
|
||||
|
||||
@pytest.mark.asyncio
@pytest.mark.parametrize(
    "manager_with_2_workers, expected_messages",
    [
        ({"stream_messags": ["Hello", " world."]}, " world."),
        ({"stream_messags": ["你好,我是", "张三。"]}, "张三。"),
    ],
    indirect=["manager_with_2_workers"],
)
async def test_generate(
    manager_with_2_workers: Tuple[
        LocalWorkerManager, List[Tuple[ModelWorker, ModelWorkerParameters]]
    ],
    expected_messages: str,
):
    """Non-streaming generate returns only the final chunk's text.

    The parametrized expectations are the *last* stream message, not the
    concatenation — this pins the non-streaming contract of the manager.
    """
    manager, workers = manager_with_2_workers
    for _worker, params in workers:
        output = await manager.generate({"model": params.model_name})
        assert output.text == expected_messages
|
||||
|
||||
|
||||
@pytest.mark.asyncio
@pytest.mark.parametrize(
    "manager_2_embedding_workers, expected_embedding, is_async",
    [
        ({"embeddings": [[1, 2, 3], [4, 5, 6]]}, [[1, 2, 3], [4, 5, 6]], True),
        ({"embeddings": [[0, 0, 0], [1, 1, 1]]}, [[0, 0, 0], [1, 1, 1]], False),
    ],
    indirect=["manager_2_embedding_workers"],
)
async def test_embeddings(
    manager_2_embedding_workers: Tuple[
        LocalWorkerManager, List[Tuple[ModelWorker, ModelWorkerParameters]]
    ],
    expected_embedding: List[List[int]],
    is_async: bool,
):
    """Embedding requests return the fixture-configured vectors unchanged.

    Covers both the async (`embeddings`) and sync (`sync_embeddings`) paths.
    """
    manager, workers = manager_2_embedding_workers
    for _worker, params in workers:
        request = {"model": params.model_name, "input": ["hello", "world"]}
        result = (
            await manager.embeddings(request)
            if is_async
            else manager.sync_embeddings(request)
        )
        assert result == expected_embedding
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_parameter_descriptions(
    manager_with_2_workers: Tuple[
        LocalWorkerManager, List[Tuple[ModelWorker, ModelWorkerParameters]]
    ]
):
    """Parameter descriptions are available for every registered model."""
    manager, workers = manager_with_2_workers
    for _worker, params in workers:
        descriptions = await manager.parameter_descriptions(
            params.worker_type, params.model_name
        )
        assert descriptions is not None
        # The worker parameter set is expected to expose well over five fields.
        assert len(descriptions) > 5
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test__update_all_worker_params():
    # TODO: implement coverage for `_update_all_worker_params` once its
    # observable behavior is defined; placeholder so the case is not forgotten.
    pass
|
@ -46,6 +46,18 @@ class ModelControllerParameters(BaseParameters):
|
||||
],
|
||||
},
|
||||
)
|
||||
log_file: Optional[str] = field(
|
||||
default="dbgpt_model_controller.log",
|
||||
metadata={
|
||||
"help": "The filename to store log",
|
||||
},
|
||||
)
|
||||
tracer_file: Optional[str] = field(
|
||||
default="dbgpt_model_controller_tracer.jsonl",
|
||||
metadata={
|
||||
"help": "The filename to store tracer span records",
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
@ -122,6 +134,18 @@ class ModelWorkerParameters(BaseModelParameters):
|
||||
],
|
||||
},
|
||||
)
|
||||
log_file: Optional[str] = field(
|
||||
default="dbgpt_model_worker_manager.log",
|
||||
metadata={
|
||||
"help": "The filename to store log",
|
||||
},
|
||||
)
|
||||
tracer_file: Optional[str] = field(
|
||||
default="dbgpt_model_worker_manager_tracer.jsonl",
|
||||
metadata={
|
||||
"help": "The filename to store tracer span records",
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
|
@ -95,6 +95,19 @@ class WebWerverParameters(BaseParameters):
|
||||
daemon: Optional[bool] = field(
|
||||
default=False, metadata={"help": "Run Webserver in background"}
|
||||
)
|
||||
controller_addr: Optional[str] = field(
|
||||
default=None,
|
||||
metadata={
|
||||
"help": "The Model controller address to connect. If None, read model controller address from environment key `MODEL_SERVER`."
|
||||
},
|
||||
)
|
||||
model_name: str = field(
|
||||
default=None,
|
||||
metadata={
|
||||
"help": "The default model name to use. If None, read model name from environment key `LLM_MODEL`.",
|
||||
"tags": "fixed",
|
||||
},
|
||||
)
|
||||
share: Optional[bool] = field(
|
||||
default=False,
|
||||
metadata={
|
||||
@ -123,3 +136,15 @@ class WebWerverParameters(BaseParameters):
|
||||
},
|
||||
)
|
||||
light: Optional[bool] = field(default=False, metadata={"help": "enable light mode"})
|
||||
log_file: Optional[str] = field(
|
||||
default="dbgpt_webserver.log",
|
||||
metadata={
|
||||
"help": "The filename to store log",
|
||||
},
|
||||
)
|
||||
tracer_file: Optional[str] = field(
|
||||
default="dbgpt_webserver_tracer.jsonl",
|
||||
metadata={
|
||||
"help": "The filename to store tracer span records",
|
||||
},
|
||||
)
|
||||
|
@ -119,7 +119,7 @@ def initialize_app(param: WebWerverParameters = None, args: List[str] = None):
|
||||
if not param.log_level:
|
||||
param.log_level = _get_logging_level()
|
||||
setup_logging(
|
||||
"pilot", logging_level=param.log_level, logger_filename="dbgpt_webserver.log"
|
||||
"pilot", logging_level=param.log_level, logger_filename=param.log_file
|
||||
)
|
||||
# Before start
|
||||
system_app.before_start()
|
||||
@ -133,14 +133,16 @@ def initialize_app(param: WebWerverParameters = None, args: List[str] = None):
|
||||
model_start_listener = _create_model_start_listener(system_app)
|
||||
initialize_components(param, system_app, embedding_model_name, embedding_model_path)
|
||||
|
||||
model_path = LLM_MODEL_CONFIG.get(CFG.LLM_MODEL)
|
||||
model_name = param.model_name or CFG.LLM_MODEL
|
||||
|
||||
model_path = LLM_MODEL_CONFIG.get(model_name)
|
||||
if not param.light:
|
||||
print("Model Unified Deployment Mode!")
|
||||
if not param.remote_embedding:
|
||||
embedding_model_name, embedding_model_path = None, None
|
||||
initialize_worker_manager_in_client(
|
||||
app=app,
|
||||
model_name=CFG.LLM_MODEL,
|
||||
model_name=model_name,
|
||||
model_path=model_path,
|
||||
local_port=param.port,
|
||||
embedding_model_name=embedding_model_name,
|
||||
@ -152,12 +154,13 @@ def initialize_app(param: WebWerverParameters = None, args: List[str] = None):
|
||||
CFG.NEW_SERVER_MODE = True
|
||||
else:
|
||||
# MODEL_SERVER is controller address now
|
||||
controller_addr = param.controller_addr or CFG.MODEL_SERVER
|
||||
initialize_worker_manager_in_client(
|
||||
app=app,
|
||||
model_name=CFG.LLM_MODEL,
|
||||
model_name=model_name,
|
||||
model_path=model_path,
|
||||
run_locally=False,
|
||||
controller_addr=CFG.MODEL_SERVER,
|
||||
controller_addr=controller_addr,
|
||||
local_port=param.port,
|
||||
start_listener=model_start_listener,
|
||||
system_app=system_app,
|
||||
@ -182,7 +185,7 @@ def run_uvicorn(param: WebWerverParameters):
|
||||
def run_webserver(param: WebWerverParameters = None):
|
||||
if not param:
|
||||
param = _get_webserver_params()
|
||||
initialize_tracer(system_app, os.path.join(LOGDIR, "dbgpt_webserver_tracer.jsonl"))
|
||||
initialize_tracer(system_app, os.path.join(LOGDIR, param.tracer_file))
|
||||
|
||||
with root_tracer.start_span(
|
||||
"run_webserver",
|
||||
|
@ -1,6 +1,7 @@
|
||||
import os
|
||||
import json
|
||||
import time
|
||||
import datetime
|
||||
import threading
|
||||
import queue
|
||||
import logging
|
||||
@ -27,6 +28,13 @@ class FileSpanStorage(SpanStorage):
|
||||
def __init__(self, filename: str, batch_size=10, flush_interval=10):
|
||||
super().__init__()
|
||||
self.filename = filename
|
||||
# Split filename into prefix and suffix
|
||||
self.filename_prefix, self.filename_suffix = os.path.splitext(filename)
|
||||
if not self.filename_suffix:
|
||||
self.filename_suffix = ".log"
|
||||
self.last_date = (
|
||||
datetime.datetime.now().date()
|
||||
) # Store the current date for checking date changes
|
||||
self.queue = queue.Queue()
|
||||
self.batch_size = batch_size
|
||||
self.flush_interval = flush_interval
|
||||
@ -52,7 +60,21 @@ class FileSpanStorage(SpanStorage):
|
||||
except queue.Full:
|
||||
pass # If the signal queue is full, it's okay. The flush thread will handle it.
|
||||
|
||||
def _get_dated_filename(self, date: datetime.date) -> str:
|
||||
"""Return the filename based on a specific date."""
|
||||
date_str = date.strftime("%Y-%m-%d")
|
||||
return f"{self.filename_prefix}_{date_str}{self.filename_suffix}"
|
||||
|
||||
def _roll_over_if_needed(self):
    """Rename the current log file when the calendar day has changed.

    Compares today's date against the date recorded at the last write
    (``self.last_date``); on a change, the current file is renamed to the
    dated form from ``_get_dated_filename`` so subsequent spans start a
    fresh file.
    """
    current_date = datetime.datetime.now().date()
    if current_date != self.last_date:
        if os.path.exists(self.filename):
            # Archive under the *previous* day's date — that is the day
            # whose spans the file actually contains.
            os.rename(self.filename, self._get_dated_filename(self.last_date))
        self.last_date = current_date
||||
|
||||
def _write_to_file(self):
|
||||
self._roll_over_if_needed()
|
||||
spans_to_write = []
|
||||
while not self.queue.empty():
|
||||
spans_to_write.append(self.queue.get())
|
||||
|
@ -4,6 +4,8 @@ import asyncio
|
||||
import json
|
||||
import tempfile
|
||||
import time
|
||||
from unittest.mock import patch
|
||||
from datetime import datetime, timedelta
|
||||
|
||||
from pilot.utils.tracer import SpanStorage, FileSpanStorage, Span, SpanType
|
||||
|
||||
@ -122,3 +124,46 @@ def test_non_existent_file(storage: SpanStorage):
|
||||
assert len(spans_in_file) == 2
|
||||
assert spans_in_file[0]["trace_id"] == "1"
|
||||
assert spans_in_file[1]["trace_id"] == "2"
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
    "storage", [{"batch_size": 1, "file_does_not_exist": True}], indirect=True
)
def test_log_rollover(storage: SpanStorage):
    # Mock start date: one minute before midnight, so advancing the mocked
    # clock by one minute crosses a day boundary and must trigger a rollover.
    mock_start_date = datetime(2023, 10, 18, 23, 59)

    with patch("datetime.datetime") as mock_datetime:
        mock_datetime.now.return_value = mock_start_date

        span1 = Span("1", "a", SpanType.BASE, "b", "op1")
        storage.append_span(span1)
        time.sleep(0.1)

        # Advance the mocked clock past midnight (a new day).
        mock_datetime.now.return_value = mock_start_date + timedelta(minutes=1)

        span2 = Span("2", "c", SpanType.BASE, "d", "op2")
        storage.append_span(span2)
        time.sleep(0.1)

        # The original (undated) filename must still exist after rollover.
        assert os.path.exists(storage.filename)

        # Expected rolled-over filename, dated with the *previous* day.
        dated_filename = os.path.join(
            os.path.dirname(storage.filename),
            f"{os.path.basename(storage.filename).split('.')[0]}_2023-10-18.jsonl",
        )

        assert os.path.exists(dated_filename)

        # The current file should contain only the second span...
        spans_in_original_file = read_spans_from_file(storage.filename)
        assert len(spans_in_original_file) == 1
        assert spans_in_original_file[0]["trace_id"] == "2"

        # ...and the dated (rolled-over) file only the first span.
        spans_in_dated_file = read_spans_from_file(dated_filename)
        assert len(spans_in_dated_file) == 1
        assert spans_in_dated_file[0]["trace_id"] == "1"
|
||||
|
Loading…
Reference in New Issue
Block a user