mirror of
https://github.com/csunny/DB-GPT.git
synced 2025-07-24 12:45:45 +00:00
feat(core): New command line for analyze and visualize trace spans
This commit is contained in:
parent
1e919aeef3
commit
2a46909eac
@ -15,7 +15,7 @@ IMAGE_NAME_ARGS=""
|
||||
PIP_INDEX_URL="https://pypi.org/simple"
|
||||
# en or zh
|
||||
LANGUAGE="en"
|
||||
BUILD_LOCAL_CODE="false"
|
||||
BUILD_LOCAL_CODE="true"
|
||||
LOAD_EXAMPLES="true"
|
||||
BUILD_NETWORK=""
|
||||
DB_GPT_INSTALL_MODEL="default"
|
||||
@ -26,7 +26,7 @@ usage () {
|
||||
echo " [-n|--image-name image name] Current image name, default: db-gpt"
|
||||
echo " [-i|--pip-index-url pip index url] Pip index url, default: https://pypi.org/simple"
|
||||
echo " [--language en or zh] You language, default: en"
|
||||
echo " [--build-local-code true or false] Whether to use the local project code to package the image, default: false"
|
||||
echo " [--build-local-code true or false] Whether to use the local project code to package the image, default: true"
|
||||
echo " [--load-examples true or false] Whether to load examples to default database default: true"
|
||||
echo " [--network network name] The network of docker build"
|
||||
echo " [--install-mode mode name] Installation mode name, default: default, If you completely use openai's service, you can set the mode name to 'openai'"
|
||||
|
@ -18,11 +18,13 @@ class PromptRequest(BaseModel):
|
||||
max_new_tokens: int = None
|
||||
stop: str = None
|
||||
echo: bool = True
|
||||
span_id: str = None
|
||||
|
||||
|
||||
class EmbeddingsRequest(BaseModel):
|
||||
model: str
|
||||
input: List[str]
|
||||
span_id: str = None
|
||||
|
||||
|
||||
class WorkerApplyRequest(BaseModel):
|
||||
|
@ -3,6 +3,8 @@ from __future__ import annotations
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from pilot.model.parameter import BaseEmbeddingModelParameters
|
||||
from pilot.utils.parameter_utils import _get_dict_from_obj
|
||||
from pilot.utils.tracer import root_tracer, SpanType, SpanTypeRunName
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from langchain.embeddings.base import Embeddings
|
||||
@ -15,13 +17,21 @@ class EmbeddingLoader:
|
||||
def load(
|
||||
self, model_name: str, param: BaseEmbeddingModelParameters
|
||||
) -> "Embeddings":
|
||||
# add more models
|
||||
if model_name in ["proxy_openai", "proxy_azure"]:
|
||||
from langchain.embeddings import OpenAIEmbeddings
|
||||
metadata = {
|
||||
"model_name": model_name,
|
||||
"run_service": SpanTypeRunName.EMBEDDING_MODEL.value,
|
||||
"params": _get_dict_from_obj(param),
|
||||
}
|
||||
with root_tracer.start_span(
|
||||
"EmbeddingLoader.load", span_type=SpanType.RUN, metadata=metadata
|
||||
):
|
||||
# add more models
|
||||
if model_name in ["proxy_openai", "proxy_azure"]:
|
||||
from langchain.embeddings import OpenAIEmbeddings
|
||||
|
||||
return OpenAIEmbeddings(**param.build_kwargs())
|
||||
else:
|
||||
from langchain.embeddings import HuggingFaceEmbeddings
|
||||
return OpenAIEmbeddings(**param.build_kwargs())
|
||||
else:
|
||||
from langchain.embeddings import HuggingFaceEmbeddings
|
||||
|
||||
kwargs = param.build_kwargs(model_name=param.model_path)
|
||||
return HuggingFaceEmbeddings(**kwargs)
|
||||
kwargs = param.build_kwargs(model_name=param.model_path)
|
||||
return HuggingFaceEmbeddings(**kwargs)
|
||||
|
@ -9,8 +9,8 @@ from pilot.model.loader import ModelLoader, _get_model_real_path
|
||||
from pilot.model.parameter import ModelParameters
|
||||
from pilot.model.cluster.worker_base import ModelWorker
|
||||
from pilot.utils.model_utils import _clear_model_cache
|
||||
from pilot.utils.parameter_utils import EnvArgumentParser
|
||||
from pilot.utils.tracer import root_tracer
|
||||
from pilot.utils.parameter_utils import EnvArgumentParser, _get_dict_from_obj
|
||||
from pilot.utils.tracer import root_tracer, SpanType, SpanTypeRunName
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@ -95,9 +95,20 @@ class DefaultModelWorker(ModelWorker):
|
||||
model_params = self.parse_parameters(command_args)
|
||||
self._model_params = model_params
|
||||
logger.info(f"Begin load model, model params: {model_params}")
|
||||
self.model, self.tokenizer = self.ml.loader_with_params(
|
||||
model_params, self.llm_adapter
|
||||
)
|
||||
metadata = {
|
||||
"model_name": self.model_name,
|
||||
"model_path": self.model_path,
|
||||
"model_type": self.llm_adapter.model_type(),
|
||||
"llm_adapter": str(self.llm_adapter),
|
||||
"run_service": SpanTypeRunName.MODEL_WORKER,
|
||||
"params": _get_dict_from_obj(model_params),
|
||||
}
|
||||
with root_tracer.start_span(
|
||||
"DefaultModelWorker.start", span_type=SpanType.RUN, metadata=metadata
|
||||
):
|
||||
self.model, self.tokenizer = self.ml.loader_with_params(
|
||||
model_params, self.llm_adapter
|
||||
)
|
||||
|
||||
def stop(self) -> None:
|
||||
if not self.model:
|
||||
@ -110,7 +121,9 @@ class DefaultModelWorker(ModelWorker):
|
||||
_clear_model_cache(self._model_params.device)
|
||||
|
||||
def generate_stream(self, params: Dict) -> Iterator[ModelOutput]:
|
||||
span = root_tracer.start_span("DefaultModelWorker.generate_stream")
|
||||
span = root_tracer.start_span(
|
||||
"DefaultModelWorker.generate_stream", params.get("span_id")
|
||||
)
|
||||
try:
|
||||
(
|
||||
params,
|
||||
@ -153,7 +166,9 @@ class DefaultModelWorker(ModelWorker):
|
||||
raise NotImplementedError
|
||||
|
||||
async def async_generate_stream(self, params: Dict) -> Iterator[ModelOutput]:
|
||||
span = root_tracer.start_span("DefaultModelWorker.async_generate_stream")
|
||||
span = root_tracer.start_span(
|
||||
"DefaultModelWorker.async_generate_stream", params.get("span_id")
|
||||
)
|
||||
try:
|
||||
(
|
||||
params,
|
||||
|
@ -8,12 +8,13 @@ import sys
|
||||
import time
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
from dataclasses import asdict
|
||||
from typing import Awaitable, Callable, Dict, Iterator, List, Optional
|
||||
from typing import Awaitable, Callable, Dict, Iterator, List
|
||||
|
||||
from fastapi import APIRouter, FastAPI
|
||||
from fastapi.responses import StreamingResponse
|
||||
|
||||
from pilot.component import SystemApp
|
||||
from pilot.configs.model_config import LOGDIR
|
||||
from pilot.model.base import (
|
||||
ModelInstance,
|
||||
ModelOutput,
|
||||
@ -35,8 +36,10 @@ from pilot.utils.parameter_utils import (
|
||||
EnvArgumentParser,
|
||||
ParameterDescription,
|
||||
_dict_to_command_args,
|
||||
_get_dict_from_obj,
|
||||
)
|
||||
from pilot.utils.utils import setup_logging
|
||||
from pilot.utils.tracer import initialize_tracer, root_tracer, SpanType, SpanTypeRunName
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@ -293,60 +296,72 @@ class LocalWorkerManager(WorkerManager):
|
||||
self, params: Dict, async_wrapper=None, **kwargs
|
||||
) -> Iterator[ModelOutput]:
|
||||
"""Generate stream result, chat scene"""
|
||||
try:
|
||||
worker_run_data = await self._get_model(params)
|
||||
except Exception as e:
|
||||
yield ModelOutput(
|
||||
text=f"**LLMServer Generate Error, Please CheckErrorInfo.**: {e}",
|
||||
error_code=0,
|
||||
)
|
||||
return
|
||||
async with worker_run_data.semaphore:
|
||||
if worker_run_data.worker.support_async():
|
||||
async for outout in worker_run_data.worker.async_generate_stream(
|
||||
params
|
||||
):
|
||||
yield outout
|
||||
else:
|
||||
if not async_wrapper:
|
||||
from starlette.concurrency import iterate_in_threadpool
|
||||
with root_tracer.start_span(
|
||||
"WorkerManager.generate_stream", params.get("span_id")
|
||||
) as span:
|
||||
params["span_id"] = span.span_id
|
||||
try:
|
||||
worker_run_data = await self._get_model(params)
|
||||
except Exception as e:
|
||||
yield ModelOutput(
|
||||
text=f"**LLMServer Generate Error, Please CheckErrorInfo.**: {e}",
|
||||
error_code=0,
|
||||
)
|
||||
return
|
||||
async with worker_run_data.semaphore:
|
||||
if worker_run_data.worker.support_async():
|
||||
async for outout in worker_run_data.worker.async_generate_stream(
|
||||
params
|
||||
):
|
||||
yield outout
|
||||
else:
|
||||
if not async_wrapper:
|
||||
from starlette.concurrency import iterate_in_threadpool
|
||||
|
||||
async_wrapper = iterate_in_threadpool
|
||||
async for output in async_wrapper(
|
||||
worker_run_data.worker.generate_stream(params)
|
||||
):
|
||||
yield output
|
||||
async_wrapper = iterate_in_threadpool
|
||||
async for output in async_wrapper(
|
||||
worker_run_data.worker.generate_stream(params)
|
||||
):
|
||||
yield output
|
||||
|
||||
async def generate(self, params: Dict) -> ModelOutput:
|
||||
"""Generate non stream result"""
|
||||
try:
|
||||
worker_run_data = await self._get_model(params)
|
||||
except Exception as e:
|
||||
return ModelOutput(
|
||||
text=f"**LLMServer Generate Error, Please CheckErrorInfo.**: {e}",
|
||||
error_code=0,
|
||||
)
|
||||
async with worker_run_data.semaphore:
|
||||
if worker_run_data.worker.support_async():
|
||||
return await worker_run_data.worker.async_generate(params)
|
||||
else:
|
||||
return await self.run_blocking_func(
|
||||
worker_run_data.worker.generate, params
|
||||
with root_tracer.start_span(
|
||||
"WorkerManager.generate", params.get("span_id")
|
||||
) as span:
|
||||
params["span_id"] = span.span_id
|
||||
try:
|
||||
worker_run_data = await self._get_model(params)
|
||||
except Exception as e:
|
||||
return ModelOutput(
|
||||
text=f"**LLMServer Generate Error, Please CheckErrorInfo.**: {e}",
|
||||
error_code=0,
|
||||
)
|
||||
async with worker_run_data.semaphore:
|
||||
if worker_run_data.worker.support_async():
|
||||
return await worker_run_data.worker.async_generate(params)
|
||||
else:
|
||||
return await self.run_blocking_func(
|
||||
worker_run_data.worker.generate, params
|
||||
)
|
||||
|
||||
async def embeddings(self, params: Dict) -> List[List[float]]:
|
||||
"""Embed input"""
|
||||
try:
|
||||
worker_run_data = await self._get_model(params, worker_type="text2vec")
|
||||
except Exception as e:
|
||||
raise e
|
||||
async with worker_run_data.semaphore:
|
||||
if worker_run_data.worker.support_async():
|
||||
return await worker_run_data.worker.async_embeddings(params)
|
||||
else:
|
||||
return await self.run_blocking_func(
|
||||
worker_run_data.worker.embeddings, params
|
||||
)
|
||||
with root_tracer.start_span(
|
||||
"WorkerManager.embeddings", params.get("span_id")
|
||||
) as span:
|
||||
params["span_id"] = span.span_id
|
||||
try:
|
||||
worker_run_data = await self._get_model(params, worker_type="text2vec")
|
||||
except Exception as e:
|
||||
raise e
|
||||
async with worker_run_data.semaphore:
|
||||
if worker_run_data.worker.support_async():
|
||||
return await worker_run_data.worker.async_embeddings(params)
|
||||
else:
|
||||
return await self.run_blocking_func(
|
||||
worker_run_data.worker.embeddings, params
|
||||
)
|
||||
|
||||
def sync_embeddings(self, params: Dict) -> List[List[float]]:
|
||||
worker_run_data = self._sync_get_model(params, worker_type="text2vec")
|
||||
@ -608,6 +623,9 @@ async def generate_json_stream(params):
|
||||
@router.post("/worker/generate_stream")
|
||||
async def api_generate_stream(request: PromptRequest):
|
||||
params = request.dict(exclude_none=True)
|
||||
span_id = root_tracer.get_current_span_id()
|
||||
if "span_id" not in params and span_id:
|
||||
params["span_id"] = span_id
|
||||
generator = generate_json_stream(params)
|
||||
return StreamingResponse(generator)
|
||||
|
||||
@ -615,12 +633,18 @@ async def api_generate_stream(request: PromptRequest):
|
||||
@router.post("/worker/generate")
|
||||
async def api_generate(request: PromptRequest):
|
||||
params = request.dict(exclude_none=True)
|
||||
span_id = root_tracer.get_current_span_id()
|
||||
if "span_id" not in params and span_id:
|
||||
params["span_id"] = span_id
|
||||
return await worker_manager.generate(params)
|
||||
|
||||
|
||||
@router.post("/worker/embeddings")
|
||||
async def api_embeddings(request: EmbeddingsRequest):
|
||||
params = request.dict(exclude_none=True)
|
||||
span_id = root_tracer.get_current_span_id()
|
||||
if "span_id" not in params and span_id:
|
||||
params["span_id"] = span_id
|
||||
return await worker_manager.embeddings(params)
|
||||
|
||||
|
||||
@ -801,10 +825,18 @@ def _build_worker(worker_params: ModelWorkerParameters):
|
||||
def _start_local_worker(
|
||||
worker_manager: WorkerManagerAdapter, worker_params: ModelWorkerParameters
|
||||
):
|
||||
worker = _build_worker(worker_params)
|
||||
if not worker_manager.worker_manager:
|
||||
worker_manager.worker_manager = _create_local_model_manager(worker_params)
|
||||
worker_manager.worker_manager.add_worker(worker, worker_params)
|
||||
with root_tracer.start_span(
|
||||
"WorkerManager._start_local_worker",
|
||||
span_type=SpanType.RUN,
|
||||
metadata={
|
||||
"run_service": SpanTypeRunName.WORKER_MANAGER,
|
||||
"params": _get_dict_from_obj(worker_params),
|
||||
},
|
||||
):
|
||||
worker = _build_worker(worker_params)
|
||||
if not worker_manager.worker_manager:
|
||||
worker_manager.worker_manager = _create_local_model_manager(worker_params)
|
||||
worker_manager.worker_manager.add_worker(worker, worker_params)
|
||||
|
||||
|
||||
def _start_local_embedding_worker(
|
||||
@ -928,17 +960,17 @@ def run_worker_manager(
|
||||
# Run worker manager independently
|
||||
embedded_mod = False
|
||||
app = _setup_fastapi(worker_params)
|
||||
_start_local_worker(worker_manager, worker_params)
|
||||
_start_local_embedding_worker(
|
||||
worker_manager, embedding_model_name, embedding_model_path
|
||||
)
|
||||
else:
|
||||
_start_local_worker(worker_manager, worker_params)
|
||||
_start_local_embedding_worker(
|
||||
worker_manager, embedding_model_name, embedding_model_path
|
||||
)
|
||||
loop = asyncio.get_event_loop()
|
||||
loop.run_until_complete(worker_manager.start())
|
||||
|
||||
system_app = SystemApp(app)
|
||||
initialize_tracer(
|
||||
system_app,
|
||||
os.path.join(LOGDIR, "dbgpt_model_worker_manager_tracer.jsonl"),
|
||||
root_operation_name="DB-GPT-WorkerManager-Entry",
|
||||
)
|
||||
_start_local_worker(worker_manager, worker_params)
|
||||
_start_local_embedding_worker(
|
||||
worker_manager, embedding_model_name, embedding_model_path
|
||||
)
|
||||
|
||||
if include_router:
|
||||
app.include_router(router, prefix="/api")
|
||||
@ -946,6 +978,8 @@ def run_worker_manager(
|
||||
if not embedded_mod:
|
||||
import uvicorn
|
||||
|
||||
loop = asyncio.get_event_loop()
|
||||
loop.run_until_complete(worker_manager.start())
|
||||
uvicorn.run(
|
||||
app, host=worker_params.host, port=worker_params.port, log_level="info"
|
||||
)
|
||||
|
@ -46,7 +46,7 @@ from pilot.summary.db_summary_client import DBSummaryClient
|
||||
|
||||
from pilot.model.cluster import BaseModelController, WorkerManager, WorkerManagerFactory
|
||||
from pilot.model.base import FlatSupportedModel
|
||||
from pilot.utils.tracer import root_tracer
|
||||
from pilot.utils.tracer import root_tracer, SpanType
|
||||
|
||||
router = APIRouter()
|
||||
CFG = Config()
|
||||
@ -367,7 +367,9 @@ async def chat_completions(dialogue: ConversationVo = Body()):
|
||||
print(
|
||||
f"chat_completions:{dialogue.chat_mode},{dialogue.select_param},{dialogue.model_name}"
|
||||
)
|
||||
with root_tracer.start_span("chat_completions", metadata=dialogue.dict()) as _:
|
||||
with root_tracer.start_span(
|
||||
"get_chat_instance", span_type=SpanType.CHAT, metadata=dialogue.dict()
|
||||
):
|
||||
chat: BaseChat = get_chat_instance(dialogue)
|
||||
# background_tasks = BackgroundTasks()
|
||||
# background_tasks.add_task(release_model_semaphore)
|
||||
@ -419,11 +421,10 @@ async def model_supports(worker_manager: WorkerManager = Depends(get_worker_mana
|
||||
|
||||
|
||||
async def no_stream_generator(chat):
|
||||
span = root_tracer.start_span("no_stream_generator")
|
||||
msg = await chat.nostream_call()
|
||||
msg = msg.replace("\n", "\\n")
|
||||
yield f"data: {msg}\n\n"
|
||||
span.end()
|
||||
with root_tracer.start_span("no_stream_generator"):
|
||||
msg = await chat.nostream_call()
|
||||
msg = msg.replace("\n", "\\n")
|
||||
yield f"data: {msg}\n\n"
|
||||
|
||||
|
||||
async def stream_generator(chat, incremental: bool, model_name: str):
|
||||
@ -467,9 +468,9 @@ async def stream_generator(chat, incremental: bool, model_name: str):
|
||||
yield f"data:{msg}\n\n"
|
||||
previous_response = msg
|
||||
await asyncio.sleep(0.02)
|
||||
span.end()
|
||||
if incremental:
|
||||
yield "data: [DONE]\n\n"
|
||||
span.end()
|
||||
chat.current_message.add_ai_message(msg)
|
||||
chat.current_message.add_view_message(msg)
|
||||
chat.memory.append(chat.current_message)
|
||||
|
@ -139,7 +139,9 @@ class BaseChat(ABC):
|
||||
def _get_span_metadata(self, payload: Dict) -> Dict:
|
||||
metadata = {k: v for k, v in payload.items()}
|
||||
del metadata["prompt"]
|
||||
metadata["messages"] = list(map(lambda m: m.dict(), metadata["messages"]))
|
||||
metadata["messages"] = list(
|
||||
map(lambda m: m if isinstance(m, dict) else m.dict(), metadata["messages"])
|
||||
)
|
||||
return metadata
|
||||
|
||||
async def stream_call(self):
|
||||
@ -152,6 +154,7 @@ class BaseChat(ABC):
|
||||
span = root_tracer.start_span(
|
||||
"BaseChat.stream_call", metadata=self._get_span_metadata(payload)
|
||||
)
|
||||
payload["span_id"] = span.span_id
|
||||
try:
|
||||
from pilot.model.cluster import WorkerManagerFactory
|
||||
|
||||
@ -178,6 +181,7 @@ class BaseChat(ABC):
|
||||
span = root_tracer.start_span(
|
||||
"BaseChat.nostream_call", metadata=self._get_span_metadata(payload)
|
||||
)
|
||||
payload["span_id"] = span.span_id
|
||||
try:
|
||||
from pilot.model.cluster import WorkerManagerFactory
|
||||
|
||||
@ -185,7 +189,7 @@ class BaseChat(ABC):
|
||||
ComponentType.WORKER_MANAGER_FACTORY, WorkerManagerFactory
|
||||
).create()
|
||||
|
||||
with root_tracer.start_span("BaseChat.invoke_worker_manager.generate") as _:
|
||||
with root_tracer.start_span("BaseChat.invoke_worker_manager.generate"):
|
||||
model_output = await worker_manager.generate(payload)
|
||||
|
||||
### output parse
|
||||
@ -206,7 +210,7 @@ class BaseChat(ABC):
|
||||
"ai_response_text": ai_response_text,
|
||||
"prompt_define_response": prompt_define_response,
|
||||
}
|
||||
with root_tracer.start_span("BaseChat.do_action", metadata=metadata) as _:
|
||||
with root_tracer.start_span("BaseChat.do_action", metadata=metadata):
|
||||
### run
|
||||
result = self.do_action(prompt_define_response)
|
||||
|
||||
|
@ -119,6 +119,14 @@ except ImportError as e:
|
||||
logging.warning(f"Integrating dbgpt knowledge command line tool failed: {e}")
|
||||
|
||||
|
||||
try:
|
||||
from pilot.utils.tracer.tracer_cli import trace_cli_group
|
||||
|
||||
add_command_alias(trace_cli_group, name="trace", parent_group=cli)
|
||||
except ImportError as e:
|
||||
logging.warning(f"Integrating dbgpt trace command line tool failed: {e}")
|
||||
|
||||
|
||||
def main():
|
||||
return cli()
|
||||
|
||||
|
@ -8,8 +8,6 @@ from pilot.component import ComponentType, SystemApp
|
||||
from pilot.utils.executor_utils import DefaultExecutorFactory
|
||||
from pilot.embedding_engine.embedding_factory import EmbeddingFactory
|
||||
from pilot.server.base import WebWerverParameters
|
||||
from pilot.configs.model_config import LOGDIR
|
||||
from pilot.utils.tracer import root_tracer, initialize_tracer
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from langchain.embeddings.base import Embeddings
|
||||
@ -26,8 +24,6 @@ def initialize_components(
|
||||
):
|
||||
from pilot.model.cluster.controller.controller import controller
|
||||
|
||||
initialize_tracer(system_app, os.path.join(LOGDIR, "dbgpt_webserver_tracer.jsonl"))
|
||||
|
||||
# Register global default executor factory first
|
||||
system_app.register(DefaultExecutorFactory)
|
||||
|
||||
|
@ -7,7 +7,7 @@ ROOT_PATH = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__fi
|
||||
sys.path.append(ROOT_PATH)
|
||||
import signal
|
||||
from pilot.configs.config import Config
|
||||
from pilot.configs.model_config import LLM_MODEL_CONFIG, EMBEDDING_MODEL_CONFIG
|
||||
from pilot.configs.model_config import LLM_MODEL_CONFIG, EMBEDDING_MODEL_CONFIG, LOGDIR
|
||||
from pilot.component import SystemApp
|
||||
|
||||
from pilot.server.base import (
|
||||
@ -38,6 +38,8 @@ from pilot.utils.utils import (
|
||||
_get_logging_level,
|
||||
logging_str_to_uvicorn_level,
|
||||
)
|
||||
from pilot.utils.tracer import root_tracer, initialize_tracer, SpanType, SpanTypeRunName
|
||||
from pilot.utils.parameter_utils import _get_dict_from_obj
|
||||
|
||||
static_file_path = os.path.join(os.getcwd(), "server/static")
|
||||
|
||||
@ -98,17 +100,21 @@ def mount_static_files(app):
|
||||
app.add_exception_handler(RequestValidationError, validation_exception_handler)
|
||||
|
||||
|
||||
def _get_webserver_params(args: List[str] = None):
|
||||
from pilot.utils.parameter_utils import EnvArgumentParser
|
||||
|
||||
parser: argparse.ArgumentParser = EnvArgumentParser.create_argparse_option(
|
||||
WebWerverParameters
|
||||
)
|
||||
return WebWerverParameters(**vars(parser.parse_args(args=args)))
|
||||
|
||||
|
||||
def initialize_app(param: WebWerverParameters = None, args: List[str] = None):
|
||||
"""Initialize app
|
||||
If you use gunicorn as a process manager, initialize_app can be invoke in `on_starting` hook.
|
||||
"""
|
||||
if not param:
|
||||
from pilot.utils.parameter_utils import EnvArgumentParser
|
||||
|
||||
parser: argparse.ArgumentParser = EnvArgumentParser.create_argparse_option(
|
||||
WebWerverParameters
|
||||
)
|
||||
param = WebWerverParameters(**vars(parser.parse_args(args=args)))
|
||||
param = _get_webserver_params(args)
|
||||
|
||||
if not param.log_level:
|
||||
param.log_level = _get_logging_level()
|
||||
@ -127,7 +133,7 @@ def initialize_app(param: WebWerverParameters = None, args: List[str] = None):
|
||||
model_start_listener = _create_model_start_listener(system_app)
|
||||
initialize_components(param, system_app, embedding_model_name, embedding_model_path)
|
||||
|
||||
model_path = LLM_MODEL_CONFIG[CFG.LLM_MODEL]
|
||||
model_path = LLM_MODEL_CONFIG.get(CFG.LLM_MODEL)
|
||||
if not param.light:
|
||||
print("Model Unified Deployment Mode!")
|
||||
if not param.remote_embedding:
|
||||
@ -174,8 +180,20 @@ def run_uvicorn(param: WebWerverParameters):
|
||||
|
||||
|
||||
def run_webserver(param: WebWerverParameters = None):
|
||||
param = initialize_app(param)
|
||||
run_uvicorn(param)
|
||||
if not param:
|
||||
param = _get_webserver_params()
|
||||
initialize_tracer(system_app, os.path.join(LOGDIR, "dbgpt_webserver_tracer.jsonl"))
|
||||
|
||||
with root_tracer.start_span(
|
||||
"run_webserver",
|
||||
span_type=SpanType.RUN,
|
||||
metadata={
|
||||
"run_service": SpanTypeRunName.WEBSERVER,
|
||||
"params": _get_dict_from_obj(param),
|
||||
},
|
||||
):
|
||||
param = initialize_app(param)
|
||||
run_uvicorn(param)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
@ -13,7 +13,7 @@ from pilot.model.cluster import run_worker_manager
|
||||
|
||||
CFG = Config()
|
||||
|
||||
model_path = LLM_MODEL_CONFIG[CFG.LLM_MODEL]
|
||||
model_path = LLM_MODEL_CONFIG.get(CFG.LLM_MODEL)
|
||||
|
||||
if __name__ == "__main__":
|
||||
run_worker_manager(
|
||||
|
@ -1,6 +1,6 @@
|
||||
import argparse
|
||||
import os
|
||||
from dataclasses import dataclass, fields, MISSING, asdict, field
|
||||
from dataclasses import dataclass, fields, MISSING, asdict, field, is_dataclass
|
||||
from typing import Any, List, Optional, Type, Union, Callable, Dict
|
||||
from collections import OrderedDict
|
||||
|
||||
@ -590,6 +590,20 @@ def _extract_parameter_details(
|
||||
return descriptions
|
||||
|
||||
|
||||
def _get_dict_from_obj(obj, default_value=None) -> Optional[Dict]:
|
||||
if not obj:
|
||||
return None
|
||||
if is_dataclass(type(obj)):
|
||||
params = {}
|
||||
for field_info in fields(obj):
|
||||
value = _get_simple_privacy_field_value(obj, field_info)
|
||||
params[field_info.name] = value
|
||||
return params
|
||||
if isinstance(obj, dict):
|
||||
return obj
|
||||
return default_value
|
||||
|
||||
|
||||
class _SimpleArgParser:
|
||||
def __init__(self, *args):
|
||||
self.params = {arg.replace("_", "-"): None for arg in args}
|
||||
|
@ -1,5 +1,7 @@
|
||||
from pilot.utils.tracer.base import (
|
||||
SpanType,
|
||||
Span,
|
||||
SpanTypeRunName,
|
||||
Tracer,
|
||||
SpanStorage,
|
||||
SpanStorageType,
|
||||
@ -14,7 +16,9 @@ from pilot.utils.tracer.tracer_impl import (
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"SpanType",
|
||||
"Span",
|
||||
"SpanTypeRunName",
|
||||
"Tracer",
|
||||
"SpanStorage",
|
||||
"SpanStorageType",
|
||||
|
@ -10,22 +10,41 @@ from datetime import datetime
|
||||
from pilot.component import BaseComponent, SystemApp, ComponentType
|
||||
|
||||
|
||||
class SpanType(str, Enum):
|
||||
BASE = "base"
|
||||
RUN = "run"
|
||||
CHAT = "chat"
|
||||
|
||||
|
||||
class SpanTypeRunName(str, Enum):
|
||||
WEBSERVER = "Webserver"
|
||||
WORKER_MANAGER = "WorkerManager"
|
||||
MODEL_WORKER = "ModelWorker"
|
||||
EMBEDDING_MODEL = "EmbeddingModel"
|
||||
|
||||
@staticmethod
|
||||
def values():
|
||||
return [item.value for item in SpanTypeRunName]
|
||||
|
||||
|
||||
class Span:
|
||||
"""Represents a unit of work that is being traced.
|
||||
This can be any operation like a function call or a database query.
|
||||
"""
|
||||
|
||||
span_type: str = "base"
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
trace_id: str,
|
||||
span_id: str,
|
||||
span_type: SpanType = None,
|
||||
parent_span_id: str = None,
|
||||
operation_name: str = None,
|
||||
metadata: Dict = None,
|
||||
end_caller: Callable[[Span], None] = None,
|
||||
):
|
||||
if not span_type:
|
||||
span_type = SpanType.BASE
|
||||
self.span_type = span_type
|
||||
# The unique identifier for the entire trace
|
||||
self.trace_id = trace_id
|
||||
# Unique identifier for this span within the trace
|
||||
@ -65,7 +84,7 @@ class Span:
|
||||
|
||||
def to_dict(self) -> Dict:
|
||||
return {
|
||||
"span_type": self.span_type,
|
||||
"span_type": self.span_type.value,
|
||||
"trace_id": self.trace_id,
|
||||
"span_id": self.span_id,
|
||||
"parent_span_id": self.parent_span_id,
|
||||
@ -124,7 +143,11 @@ class Tracer(BaseComponent, ABC):
|
||||
|
||||
@abstractmethod
|
||||
def start_span(
|
||||
self, operation_name: str, parent_span_id: str = None, metadata: Dict = None
|
||||
self,
|
||||
operation_name: str,
|
||||
parent_span_id: str = None,
|
||||
span_type: SpanType = None,
|
||||
metadata: Dict = None,
|
||||
) -> Span:
|
||||
"""Begin a new span for the given operation. If provided, the span will be
|
||||
a child of the span with the given parent_span_id.
|
||||
@ -158,4 +181,4 @@ class Tracer(BaseComponent, ABC):
|
||||
|
||||
@dataclass
|
||||
class TracerContext:
|
||||
span_id: str
|
||||
span_id: Optional[str] = None
|
||||
|
540
pilot/utils/tracer/tracer_cli.py
Normal file
540
pilot/utils/tracer/tracer_cli.py
Normal file
@ -0,0 +1,540 @@
|
||||
import os
|
||||
import click
|
||||
import logging
|
||||
import glob
|
||||
import json
|
||||
from datetime import datetime
|
||||
from typing import Iterable, Dict, Callable
|
||||
from pilot.configs.model_config import LOGDIR
|
||||
from pilot.utils.tracer import SpanType, SpanTypeRunName
|
||||
|
||||
logger = logging.getLogger("dbgpt_cli")
|
||||
|
||||
|
||||
_DEFAULT_FILE_PATTERN = os.path.join(LOGDIR, "dbgpt*.jsonl")
|
||||
|
||||
|
||||
@click.group("trace")
|
||||
def trace_cli_group():
|
||||
"""Analyze and visualize trace spans."""
|
||||
pass
|
||||
|
||||
|
||||
@trace_cli_group.command()
|
||||
@click.option(
|
||||
"--trace_id",
|
||||
required=False,
|
||||
type=str,
|
||||
default=None,
|
||||
show_default=True,
|
||||
help="Specify the trace ID to list",
|
||||
)
|
||||
@click.option(
|
||||
"--span_id",
|
||||
required=False,
|
||||
type=str,
|
||||
default=None,
|
||||
show_default=True,
|
||||
help="Specify the Span ID to list.",
|
||||
)
|
||||
@click.option(
|
||||
"--span_type",
|
||||
required=False,
|
||||
type=str,
|
||||
default=None,
|
||||
show_default=True,
|
||||
help="Specify the Span Type to list.",
|
||||
)
|
||||
@click.option(
|
||||
"--parent_span_id",
|
||||
required=False,
|
||||
type=str,
|
||||
default=None,
|
||||
show_default=True,
|
||||
help="Specify the Parent Span ID to list.",
|
||||
)
|
||||
@click.option(
|
||||
"--search",
|
||||
required=False,
|
||||
type=str,
|
||||
default=None,
|
||||
show_default=True,
|
||||
help="Search trace_id, span_id, parent_span_id, operation_name or content in metadata.",
|
||||
)
|
||||
@click.option(
|
||||
"-l",
|
||||
"--limit",
|
||||
type=int,
|
||||
default=20,
|
||||
help="Limit the number of recent span displayed.",
|
||||
)
|
||||
@click.option(
|
||||
"--start_time",
|
||||
type=str,
|
||||
help='Filter by start time. Format: "YYYY-MM-DD HH:MM:SS.mmm"',
|
||||
)
|
||||
@click.option(
|
||||
"--end_time", type=str, help='Filter by end time. Format: "YYYY-MM-DD HH:MM:SS.mmm"'
|
||||
)
|
||||
@click.option(
|
||||
"--desc",
|
||||
required=False,
|
||||
type=bool,
|
||||
default=False,
|
||||
is_flag=True,
|
||||
help="Whether to use reverse sorting. By default, sorting is based on start time.",
|
||||
)
|
||||
@click.option(
|
||||
"--output",
|
||||
required=False,
|
||||
type=click.Choice(["text", "html", "csv", "latex", "json"]),
|
||||
default="text",
|
||||
help="The output format",
|
||||
)
|
||||
@click.argument("files", nargs=-1, type=click.Path(exists=True, readable=True))
|
||||
def list(
|
||||
trace_id: str,
|
||||
span_id: str,
|
||||
span_type: str,
|
||||
parent_span_id: str,
|
||||
search: str,
|
||||
limit: int,
|
||||
start_time: str,
|
||||
end_time: str,
|
||||
desc: bool,
|
||||
output: str,
|
||||
files=None,
|
||||
):
|
||||
"""List your trace spans"""
|
||||
from prettytable import PrettyTable
|
||||
|
||||
# If no files are explicitly specified, use the default pattern to get them
|
||||
spans = read_spans_from_files(files)
|
||||
|
||||
if trace_id:
|
||||
spans = filter(lambda s: s["trace_id"] == trace_id, spans)
|
||||
if span_id:
|
||||
spans = filter(lambda s: s["span_id"] == span_id, spans)
|
||||
if span_type:
|
||||
spans = filter(lambda s: s["span_type"] == span_type, spans)
|
||||
if parent_span_id:
|
||||
spans = filter(lambda s: s["parent_span_id"] == parent_span_id, spans)
|
||||
# Filter spans based on the start and end times
|
||||
if start_time:
|
||||
start_dt = _parse_datetime(start_time)
|
||||
spans = filter(
|
||||
lambda span: _parse_datetime(span["start_time"]) >= start_dt, spans
|
||||
)
|
||||
|
||||
if end_time:
|
||||
end_dt = _parse_datetime(end_time)
|
||||
spans = filter(
|
||||
lambda span: _parse_datetime(span["start_time"]) <= end_dt, spans
|
||||
)
|
||||
|
||||
if search:
|
||||
spans = filter(_new_search_span_func(search), spans)
|
||||
|
||||
# Sort spans based on the start time
|
||||
spans = sorted(
|
||||
spans, key=lambda span: _parse_datetime(span["start_time"]), reverse=desc
|
||||
)[:limit]
|
||||
|
||||
table = PrettyTable(
|
||||
["Trace ID", "Span ID", "Operation Name", "Conversation UID"],
|
||||
)
|
||||
|
||||
for sp in spans:
|
||||
conv_uid = None
|
||||
if "metadata" in sp and sp:
|
||||
metadata = sp["metadata"]
|
||||
if isinstance(metadata, dict):
|
||||
conv_uid = metadata.get("conv_uid")
|
||||
table.add_row(
|
||||
[
|
||||
sp.get("trace_id"),
|
||||
sp.get("span_id"),
|
||||
# sp.get("parent_span_id"),
|
||||
sp.get("operation_name"),
|
||||
conv_uid,
|
||||
]
|
||||
)
|
||||
out_kwargs = {"ensure_ascii": False} if output == "json" else {}
|
||||
print(table.get_formatted_string(out_format=output, **out_kwargs))
|
||||
|
||||
|
||||
@trace_cli_group.command()
@click.option(
    "--trace_id",
    required=True,
    type=str,
    help="Specify the trace ID to list",
)
@click.argument("files", nargs=-1, type=click.Path(exists=True, readable=True))
def tree(trace_id: str, files):
    """Display trace links as a tree.

    Builds the span hierarchy for the given trace ID from the span files
    and prints it as an indented tree.
    """
    _print_trace_hierarchy(_view_trace_hierarchy(trace_id, files))
|
||||
|
||||
|
||||
@trace_cli_group.command()
@click.option(
    "--trace_id",
    required=False,
    type=str,
    default=None,
    help="Specify the trace ID to analyze. If None, show latest conversation details",
)
@click.option(
    "--tree",
    required=False,
    type=bool,
    default=False,
    is_flag=True,
    help="Display trace spans as a tree",
)
@click.option(
    "--hide_run_params",
    required=False,
    type=bool,
    default=False,
    is_flag=True,
    help="Hide run params",
)
@click.option(
    "--output",
    required=False,
    type=click.Choice(["text", "html", "csv", "latex", "json"]),
    default="text",
    help="The output format",
)
@click.argument("files", nargs=-1, type=click.Path(exists=False, readable=True))
def chat(trace_id: str, tree: bool, hide_run_params: bool, output: str, files):
    """Show conversation details.

    Prints each service's run-time configuration (unless --hide_run_params),
    then a detail table for one conversation: the trace given by --trace_id,
    or the most recent CHAT trace when no ID is supplied.
    """
    from prettytable import PrettyTable

    spans = read_spans_from_files(files)

    # Sort newest-first so the first CHAT span encountered below is the
    # latest conversation.
    spans = sorted(
        spans, key=lambda span: _parse_datetime(span["start_time"]), reverse=True
    )
    spans = [sp for sp in spans]
    service_spans = {}
    service_names = set(SpanTypeRunName.values())
    found_trace_id = None
    for sp in spans:
        span_type = sp["span_type"]
        metadata = sp.get("metadata")
        if span_type == SpanType.RUN:
            service_name = metadata["run_service"]
            service_spans[service_name] = sp.copy()
            # Stop early once every service config and a chat trace are found.
            if set(service_spans.keys()) == service_names and found_trace_id:
                break
        elif span_type == SpanType.CHAT and not found_trace_id:
            if not trace_id:
                found_trace_id = sp["trace_id"]
            if trace_id and trace_id == sp["trace_id"]:
                found_trace_id = trace_id

    service_tables = {}
    out_kwargs = {"ensure_ascii": False} if output == "json" else {}
    for service_name, sp in service_spans.items():
        metadata = sp["metadata"]
        table = PrettyTable(["Config Key", "Config Value"], title=service_name)
        for k, v in metadata["params"].items():
            table.add_row([k, v])
        service_tables[service_name] = table

    if not hide_run_params:
        merged_table1 = merge_tables_horizontally(
            [
                service_tables.get(SpanTypeRunName.WEBSERVER.value),
                service_tables.get(SpanTypeRunName.EMBEDDING_MODEL.value),
            ]
        )
        merged_table2 = merge_tables_horizontally(
            [
                # BUGFIX: service_tables is keyed by string service names, so
                # look up with ``.value`` like the WEBSERVER/EMBEDDING_MODEL
                # lookups above (raw enum members never matched).
                service_tables.get(SpanTypeRunName.MODEL_WORKER.value),
                service_tables.get(SpanTypeRunName.WORKER_MANAGER.value),
            ]
        )
        if output == "text":
            print(merged_table1)
            print(merged_table2)
        else:
            # Non-text formats can't merge horizontally; print each table.
            for service_name, table in service_tables.items():
                print(table.get_formatted_string(out_format=output, **out_kwargs))

    if not found_trace_id:
        # BUGFIX: corrected grammar of user-facing message ("Can't found").
        print(f"Can't find conversation with trace_id: {trace_id}")
        return
    trace_id = found_trace_id

    trace_spans = [span for span in spans if span["trace_id"] == trace_id]
    # Restore chronological order (the global sort above was newest-first).
    trace_spans = [s for s in reversed(trace_spans)]
    hierarchy = _build_trace_hierarchy(trace_spans)
    if tree:
        print("\nInvoke Trace Tree:\n")
        _print_trace_hierarchy(hierarchy)

    trace_spans = _get_ordered_trace_from(hierarchy)
    # BUGFIX: column header typo was "Value Value".
    table = PrettyTable(["Key", "Value"], title="Chat Trace Details")
    # Only wrap long values for plain-text output.
    split_long_text = output == "text"

    for sp in trace_spans:
        op = sp["operation_name"]
        metadata = sp.get("metadata")
        if op == "get_chat_instance" and not sp["end_time"]:
            # Start span of the chat entry point carries the request params.
            table.add_row(["trace_id", trace_id])
            table.add_row(["span_id", sp["span_id"]])
            table.add_row(["conv_uid", metadata.get("conv_uid")])
            table.add_row(["user_input", metadata.get("user_input")])
            table.add_row(["chat_mode", metadata.get("chat_mode")])
            table.add_row(["select_param", metadata.get("select_param")])
            table.add_row(["model_name", metadata.get("model_name")])
        if op in ["BaseChat.stream_call", "BaseChat.nostream_call"]:
            if not sp["end_time"]:
                table.add_row(["temperature", metadata.get("temperature")])
                table.add_row(["max_new_tokens", metadata.get("max_new_tokens")])
                table.add_row(["echo", metadata.get("echo")])
            elif "error" in metadata:
                table.add_row(["BaseChat Error", metadata.get("error")])
        if op == "BaseChat.nostream_call" and not sp["end_time"]:
            if "model_output" in metadata:
                table.add_row(
                    [
                        "BaseChat model_output",
                        split_string_by_terminal_width(
                            metadata.get("model_output").get("text"),
                            split=split_long_text,
                        ),
                    ]
                )
            if "ai_response_text" in metadata:
                table.add_row(
                    [
                        "BaseChat ai_response_text",
                        split_string_by_terminal_width(
                            metadata.get("ai_response_text"), split=split_long_text
                        ),
                    ]
                )
            if "prompt_define_response" in metadata:
                table.add_row(
                    [
                        "BaseChat prompt_define_response",
                        split_string_by_terminal_width(
                            metadata.get("prompt_define_response"),
                            split=split_long_text,
                        ),
                    ]
                )
        if op == "DefaultModelWorker_call.generate_stream_func":
            if not sp["end_time"]:
                table.add_row(["llm_adapter", metadata.get("llm_adapter")])
                table.add_row(
                    [
                        "User prompt",
                        split_string_by_terminal_width(
                            metadata.get("prompt"), split=split_long_text
                        ),
                    ]
                )
            else:
                table.add_row(
                    [
                        "Model output",
                        # BUGFIX: pass split=split_long_text like every other
                        # call so non-text formats are not line-wrapped.
                        split_string_by_terminal_width(
                            metadata.get("output"), split=split_long_text
                        ),
                    ]
                )
        if (
            op
            in [
                "DefaultModelWorker.async_generate_stream",
                "DefaultModelWorker.generate_stream",
            ]
            and metadata
            and "error" in metadata
        ):
            table.add_row(["Model Error", metadata.get("error")])
    print(table.get_formatted_string(out_format=output, **out_kwargs))
|
||||
|
||||
|
||||
def read_spans_from_files(files=None) -> Iterable[Dict]:
    """Yield span dicts parsed from JSON-lines trace files.

    Each entry in ``files`` may be a glob pattern. When ``files`` is empty
    or None, the default span-file pattern is used instead.
    """
    patterns = files if files else [_DEFAULT_FILE_PATTERN]
    for pattern in patterns:
        for path in glob.glob(pattern):
            with open(path, "r") as fp:
                for raw_line in fp:
                    # One JSON-encoded span per line.
                    yield json.loads(raw_line)
|
||||
|
||||
|
||||
def _new_search_span_func(search: str):
|
||||
def func(span: Dict) -> bool:
|
||||
items = [span["trace_id"], span["span_id"], span["parent_span_id"]]
|
||||
if "operation_name" in span:
|
||||
items.append(span["operation_name"])
|
||||
if "metadata" in span:
|
||||
metadata = span["metadata"]
|
||||
if isinstance(metadata, dict):
|
||||
for k, v in metadata.items():
|
||||
items.append(k)
|
||||
items.append(v)
|
||||
return any(search in str(item) for item in items if item)
|
||||
|
||||
return func
|
||||
|
||||
|
||||
def _parse_datetime(dt_str):
|
||||
"""Parse a datetime string to a datetime object."""
|
||||
return datetime.strptime(dt_str, "%Y-%m-%d %H:%M:%S.%f")
|
||||
|
||||
|
||||
def _build_trace_hierarchy(spans, parent_span_id=None, indent=0):
|
||||
# Current spans
|
||||
current_level_spans = [
|
||||
span
|
||||
for span in spans
|
||||
if span["parent_span_id"] == parent_span_id and span["end_time"] is None
|
||||
]
|
||||
|
||||
hierarchy = []
|
||||
|
||||
for start_span in current_level_spans:
|
||||
# Find end span
|
||||
end_span = next(
|
||||
(
|
||||
span
|
||||
for span in spans
|
||||
if span["span_id"] == start_span["span_id"]
|
||||
and span["end_time"] is not None
|
||||
),
|
||||
None,
|
||||
)
|
||||
entry = {
|
||||
"operation_name": start_span["operation_name"],
|
||||
"parent_span_id": start_span["parent_span_id"],
|
||||
"span_id": start_span["span_id"],
|
||||
"start_time": start_span["start_time"],
|
||||
"end_time": start_span["end_time"],
|
||||
"metadata": start_span["metadata"],
|
||||
"children": _build_trace_hierarchy(
|
||||
spans, start_span["span_id"], indent + 1
|
||||
),
|
||||
}
|
||||
hierarchy.append(entry)
|
||||
|
||||
# Append end span
|
||||
if end_span:
|
||||
entry_end = {
|
||||
"operation_name": end_span["operation_name"],
|
||||
"parent_span_id": end_span["parent_span_id"],
|
||||
"span_id": end_span["span_id"],
|
||||
"start_time": end_span["start_time"],
|
||||
"end_time": end_span["end_time"],
|
||||
"metadata": end_span["metadata"],
|
||||
"children": [],
|
||||
}
|
||||
hierarchy.append(entry_end)
|
||||
|
||||
return hierarchy
|
||||
|
||||
|
||||
def _view_trace_hierarchy(trace_id, files=None):
    """Build the call hierarchy for all spans of the given trace_id."""
    matching_spans = [
        sp for sp in read_spans_from_files(files) if sp["trace_id"] == trace_id
    ]
    return _build_trace_hierarchy(matching_spans)
|
||||
|
||||
|
||||
def _print_trace_hierarchy(hierarchy, indent=0):
|
||||
"""Print link hierarchy"""
|
||||
for entry in hierarchy:
|
||||
print(
|
||||
" " * indent
|
||||
+ f"Operation: {entry['operation_name']} (Start: {entry['start_time']}, End: {entry['end_time']})"
|
||||
)
|
||||
_print_trace_hierarchy(entry["children"], indent + 1)
|
||||
|
||||
|
||||
def _get_ordered_trace_from(hierarchy):
|
||||
traces = []
|
||||
|
||||
def func(items):
|
||||
for item in items:
|
||||
traces.append(item)
|
||||
func(item["children"])
|
||||
|
||||
func(hierarchy)
|
||||
return traces
|
||||
|
||||
|
||||
def _print(service_spans: Dict):
    # NOTE(review): dead stub — the loop body is ``pass`` so this function has
    # no effect. It also mixes ``.name`` attributes with raw enum members in
    # the pair lists, suggesting an abandoned draft. Candidate for removal
    # once confirmed unused by callers elsewhere in the file.
    for names in [
        [SpanTypeRunName.WEBSERVER.name, SpanTypeRunName.EMBEDDING_MODEL],
        [SpanTypeRunName.WORKER_MANAGER.name, SpanTypeRunName.MODEL_WORKER],
    ]:
        pass
|
||||
|
||||
|
||||
def merge_tables_horizontally(tables):
    """Merge several PrettyTable instances side-by-side into one table.

    Column names are suffixed with the source table's title (when set) to
    keep them unique. Shorter tables are padded with empty cells. Falsy
    entries in ``tables`` are skipped; returns None when nothing remains.
    """
    from prettytable import PrettyTable

    if not tables:
        return None

    tables = [t for t in tables if t]
    if not tables:
        return None

    # NOTE: relies on PrettyTable's private ``_rows`` attribute.
    max_rows = max(len(t._rows) for t in tables)

    merged = PrettyTable()

    combined_names = []
    for t in tables:
        suffix = f" ({t.title})" if t.title else ""
        combined_names.extend(f"{name}{suffix}" for name in t.field_names)
    merged.field_names = combined_names

    for row_idx in range(max_rows):
        row = []
        for t in tables:
            if row_idx < len(t._rows):
                row.extend(t._rows[row_idx])
            else:
                # Pad shorter tables with empty cells.
                row.extend([""] * len(t.field_names))
        merged.add_row(row)

    return merged
|
||||
|
||||
|
||||
def split_string_by_terminal_width(s, split=True, max_len=None, sp="\n"):
    """Split ``s`` into ``sp``-joined chunks sized to ~80% of terminal width.

    Parameters:
    - s: the input string
    - split: when False, return ``s`` unchanged
    - max_len: explicit chunk length; computed from the terminal when None
    - sp: separator inserted between chunks
    """
    if not split:
        return s
    if not max_len:
        try:
            max_len = int(os.get_terminal_size().columns * 0.8)
        except OSError:
            # Terminal size unavailable (e.g. output is piped); fall back
            # to a fixed width of 100 columns.
            max_len = 100
    chunks = (s[pos : pos + max_len] for pos in range(0, len(s), max_len))
    return sp.join(chunks)
|
@ -4,6 +4,7 @@ from functools import wraps
|
||||
|
||||
from pilot.component import SystemApp, ComponentType
|
||||
from pilot.utils.tracer.base import (
|
||||
SpanType,
|
||||
Span,
|
||||
Tracer,
|
||||
SpanStorage,
|
||||
@ -32,14 +33,23 @@ class DefaultTracer(Tracer):
|
||||
self._get_current_storage().append_span(span)
|
||||
|
||||
def start_span(
|
||||
self, operation_name: str, parent_span_id: str = None, metadata: Dict = None
|
||||
self,
|
||||
operation_name: str,
|
||||
parent_span_id: str = None,
|
||||
span_type: SpanType = None,
|
||||
metadata: Dict = None,
|
||||
) -> Span:
|
||||
trace_id = (
|
||||
self._new_uuid() if parent_span_id is None else parent_span_id.split(":")[0]
|
||||
)
|
||||
span_id = f"{trace_id}:{self._new_uuid()}"
|
||||
span = Span(
|
||||
trace_id, span_id, parent_span_id, operation_name, metadata=metadata
|
||||
trace_id,
|
||||
span_id,
|
||||
span_type,
|
||||
parent_span_id,
|
||||
operation_name,
|
||||
metadata=metadata,
|
||||
)
|
||||
|
||||
if self._span_storage_type in [
|
||||
@ -81,11 +91,13 @@ class DefaultTracer(Tracer):
|
||||
|
||||
|
||||
class TracerManager:
|
||||
"""The manager of current tracer"""
|
||||
|
||||
def __init__(self) -> None:
|
||||
self._system_app: Optional[SystemApp] = None
|
||||
self._trace_context_var: ContextVar[TracerContext] = ContextVar(
|
||||
"trace_context",
|
||||
default=TracerContext(span_id="default_trace_id:default_span_id"),
|
||||
default=TracerContext(),
|
||||
)
|
||||
|
||||
def initialize(
|
||||
@ -101,18 +113,23 @@ class TracerManager:
|
||||
return self._system_app.get_component(ComponentType.TRACER, Tracer, None)
|
||||
|
||||
def start_span(
|
||||
self, operation_name: str, parent_span_id: str = None, metadata: Dict = None
|
||||
self,
|
||||
operation_name: str,
|
||||
parent_span_id: str = None,
|
||||
span_type: SpanType = None,
|
||||
metadata: Dict = None,
|
||||
) -> Span:
|
||||
"""Start a new span with operation_name
|
||||
This method must not throw an exception under any case and try not to block as much as possible
|
||||
"""
|
||||
tracer = self._get_tracer()
|
||||
if not tracer:
|
||||
return Span("empty_span", "empty_span")
|
||||
if not parent_span_id:
|
||||
current_span = self.get_current_span()
|
||||
if current_span:
|
||||
parent_span_id = current_span.span_id
|
||||
else:
|
||||
parent_span_id = self._trace_context_var.get().span_id
|
||||
return tracer.start_span(operation_name, parent_span_id, metadata)
|
||||
parent_span_id = self.get_current_span_id()
|
||||
return tracer.start_span(
|
||||
operation_name, parent_span_id, span_type=span_type, metadata=metadata
|
||||
)
|
||||
|
||||
def end_span(self, span: Span, **kwargs):
|
||||
tracer = self._get_tracer()
|
||||
@ -126,15 +143,22 @@ class TracerManager:
|
||||
return None
|
||||
return tracer.get_current_span()
|
||||
|
||||
def get_current_span_id(self) -> Optional[str]:
|
||||
current_span = self.get_current_span()
|
||||
if current_span:
|
||||
return current_span.span_id
|
||||
ctx = self._trace_context_var.get()
|
||||
return ctx.span_id if ctx else None
|
||||
|
||||
|
||||
root_tracer: TracerManager = TracerManager()
|
||||
|
||||
|
||||
def trace(operation_name: str):
|
||||
def trace(operation_name: str, **trace_kwargs):
|
||||
def decorator(func):
|
||||
@wraps(func)
|
||||
async def wrapper(*args, **kwargs):
|
||||
with root_tracer.start_span(operation_name) as span:
|
||||
with root_tracer.start_span(operation_name, **trace_kwargs):
|
||||
return await func(*args, **kwargs)
|
||||
|
||||
return wrapper
|
||||
@ -142,14 +166,18 @@ def trace(operation_name: str):
|
||||
return decorator
|
||||
|
||||
|
||||
def initialize_tracer(system_app: SystemApp, tracer_filename: str):
|
||||
def initialize_tracer(
|
||||
system_app: SystemApp,
|
||||
tracer_filename: str,
|
||||
root_operation_name: str = "DB-GPT-Web-Entry",
|
||||
):
|
||||
if not system_app:
|
||||
return
|
||||
from pilot.utils.tracer.span_storage import FileSpanStorage
|
||||
|
||||
trace_context_var = ContextVar(
|
||||
"trace_context",
|
||||
default=TracerContext(span_id="default_trace_id:default_span_id"),
|
||||
default=TracerContext(),
|
||||
)
|
||||
tracer = DefaultTracer(system_app)
|
||||
|
||||
@ -160,5 +188,8 @@ def initialize_tracer(system_app: SystemApp, tracer_filename: str):
|
||||
from pilot.utils.tracer.tracer_middleware import TraceIDMiddleware
|
||||
|
||||
system_app.app.add_middleware(
|
||||
TraceIDMiddleware, trace_context_var=trace_context_var, tracer=tracer
|
||||
TraceIDMiddleware,
|
||||
trace_context_var=trace_context_var,
|
||||
tracer=tracer,
|
||||
root_operation_name=root_operation_name,
|
||||
)
|
||||
|
@ -16,12 +16,14 @@ class TraceIDMiddleware(BaseHTTPMiddleware):
|
||||
app: ASGIApp,
|
||||
trace_context_var: ContextVar[TracerContext],
|
||||
tracer: Tracer,
|
||||
root_operation_name: str = "DB-GPT-Web-Entry",
|
||||
include_prefix: str = "/api",
|
||||
exclude_paths=_DEFAULT_EXCLUDE_PATHS,
|
||||
):
|
||||
super().__init__(app)
|
||||
self.trace_context_var = trace_context_var
|
||||
self.tracer = tracer
|
||||
self.root_operation_name = root_operation_name
|
||||
self.include_prefix = include_prefix
|
||||
self.exclude_paths = exclude_paths
|
||||
|
||||
@ -37,7 +39,7 @@ class TraceIDMiddleware(BaseHTTPMiddleware):
|
||||
# self.trace_context_var.set(TracerContext(span_id=span_id))
|
||||
|
||||
with self.tracer.start_span(
|
||||
"DB-GPT-Web-Entry", span_id, metadata={"path": request.url.path}
|
||||
) as _:
|
||||
self.root_operation_name, span_id, metadata={"path": request.url.path}
|
||||
):
|
||||
response = await call_next(request)
|
||||
return response
|
||||
|
Loading…
Reference in New Issue
Block a user