feat(core): New command line to analyze and visualize trace spans

This commit is contained in:
FangYin Cheng 2023-10-12 01:18:01 +08:00
parent 1e919aeef3
commit 2a46909eac
17 changed files with 830 additions and 128 deletions

View File

@ -15,7 +15,7 @@ IMAGE_NAME_ARGS=""
PIP_INDEX_URL="https://pypi.org/simple"
# en or zh
LANGUAGE="en"
BUILD_LOCAL_CODE="false"
BUILD_LOCAL_CODE="true"
LOAD_EXAMPLES="true"
BUILD_NETWORK=""
DB_GPT_INSTALL_MODEL="default"
@ -26,7 +26,7 @@ usage () {
echo " [-n|--image-name image name] Current image name, default: db-gpt"
echo " [-i|--pip-index-url pip index url] Pip index url, default: https://pypi.org/simple"
echo " [--language en or zh] You language, default: en"
echo " [--build-local-code true or false] Whether to use the local project code to package the image, default: false"
echo " [--build-local-code true or false] Whether to use the local project code to package the image, default: true"
echo " [--load-examples true or false] Whether to load examples to default database default: true"
echo " [--network network name] The network of docker build"
echo " [--install-mode mode name] Installation mode name, default: default, If you completely use openai's service, you can set the mode name to 'openai'"

View File

@ -18,11 +18,13 @@ class PromptRequest(BaseModel):
max_new_tokens: int = None
stop: str = None
echo: bool = True
span_id: str = None
class EmbeddingsRequest(BaseModel):
model: str
input: List[str]
span_id: str = None
class WorkerApplyRequest(BaseModel):

View File

@ -3,6 +3,8 @@ from __future__ import annotations
from typing import TYPE_CHECKING
from pilot.model.parameter import BaseEmbeddingModelParameters
from pilot.utils.parameter_utils import _get_dict_from_obj
from pilot.utils.tracer import root_tracer, SpanType, SpanTypeRunName
if TYPE_CHECKING:
from langchain.embeddings.base import Embeddings
@ -15,13 +17,21 @@ class EmbeddingLoader:
def load(
self, model_name: str, param: BaseEmbeddingModelParameters
) -> "Embeddings":
# add more models
if model_name in ["proxy_openai", "proxy_azure"]:
from langchain.embeddings import OpenAIEmbeddings
metadata = {
"model_name": model_name,
"run_service": SpanTypeRunName.EMBEDDING_MODEL.value,
"params": _get_dict_from_obj(param),
}
with root_tracer.start_span(
"EmbeddingLoader.load", span_type=SpanType.RUN, metadata=metadata
):
# add more models
if model_name in ["proxy_openai", "proxy_azure"]:
from langchain.embeddings import OpenAIEmbeddings
return OpenAIEmbeddings(**param.build_kwargs())
else:
from langchain.embeddings import HuggingFaceEmbeddings
return OpenAIEmbeddings(**param.build_kwargs())
else:
from langchain.embeddings import HuggingFaceEmbeddings
kwargs = param.build_kwargs(model_name=param.model_path)
return HuggingFaceEmbeddings(**kwargs)
kwargs = param.build_kwargs(model_name=param.model_path)
return HuggingFaceEmbeddings(**kwargs)

View File

@ -9,8 +9,8 @@ from pilot.model.loader import ModelLoader, _get_model_real_path
from pilot.model.parameter import ModelParameters
from pilot.model.cluster.worker_base import ModelWorker
from pilot.utils.model_utils import _clear_model_cache
from pilot.utils.parameter_utils import EnvArgumentParser
from pilot.utils.tracer import root_tracer
from pilot.utils.parameter_utils import EnvArgumentParser, _get_dict_from_obj
from pilot.utils.tracer import root_tracer, SpanType, SpanTypeRunName
logger = logging.getLogger(__name__)
@ -95,9 +95,20 @@ class DefaultModelWorker(ModelWorker):
model_params = self.parse_parameters(command_args)
self._model_params = model_params
logger.info(f"Begin load model, model params: {model_params}")
self.model, self.tokenizer = self.ml.loader_with_params(
model_params, self.llm_adapter
)
metadata = {
"model_name": self.model_name,
"model_path": self.model_path,
"model_type": self.llm_adapter.model_type(),
"llm_adapter": str(self.llm_adapter),
"run_service": SpanTypeRunName.MODEL_WORKER,
"params": _get_dict_from_obj(model_params),
}
with root_tracer.start_span(
"DefaultModelWorker.start", span_type=SpanType.RUN, metadata=metadata
):
self.model, self.tokenizer = self.ml.loader_with_params(
model_params, self.llm_adapter
)
def stop(self) -> None:
if not self.model:
@ -110,7 +121,9 @@ class DefaultModelWorker(ModelWorker):
_clear_model_cache(self._model_params.device)
def generate_stream(self, params: Dict) -> Iterator[ModelOutput]:
span = root_tracer.start_span("DefaultModelWorker.generate_stream")
span = root_tracer.start_span(
"DefaultModelWorker.generate_stream", params.get("span_id")
)
try:
(
params,
@ -153,7 +166,9 @@ class DefaultModelWorker(ModelWorker):
raise NotImplementedError
async def async_generate_stream(self, params: Dict) -> Iterator[ModelOutput]:
span = root_tracer.start_span("DefaultModelWorker.async_generate_stream")
span = root_tracer.start_span(
"DefaultModelWorker.async_generate_stream", params.get("span_id")
)
try:
(
params,

View File

@ -8,12 +8,13 @@ import sys
import time
from concurrent.futures import ThreadPoolExecutor
from dataclasses import asdict
from typing import Awaitable, Callable, Dict, Iterator, List, Optional
from typing import Awaitable, Callable, Dict, Iterator, List
from fastapi import APIRouter, FastAPI
from fastapi.responses import StreamingResponse
from pilot.component import SystemApp
from pilot.configs.model_config import LOGDIR
from pilot.model.base import (
ModelInstance,
ModelOutput,
@ -35,8 +36,10 @@ from pilot.utils.parameter_utils import (
EnvArgumentParser,
ParameterDescription,
_dict_to_command_args,
_get_dict_from_obj,
)
from pilot.utils.utils import setup_logging
from pilot.utils.tracer import initialize_tracer, root_tracer, SpanType, SpanTypeRunName
logger = logging.getLogger(__name__)
@ -293,60 +296,72 @@ class LocalWorkerManager(WorkerManager):
self, params: Dict, async_wrapper=None, **kwargs
) -> Iterator[ModelOutput]:
"""Generate stream result, chat scene"""
try:
worker_run_data = await self._get_model(params)
except Exception as e:
yield ModelOutput(
text=f"**LLMServer Generate Error, Please CheckErrorInfo.**: {e}",
error_code=0,
)
return
async with worker_run_data.semaphore:
if worker_run_data.worker.support_async():
async for outout in worker_run_data.worker.async_generate_stream(
params
):
yield outout
else:
if not async_wrapper:
from starlette.concurrency import iterate_in_threadpool
with root_tracer.start_span(
"WorkerManager.generate_stream", params.get("span_id")
) as span:
params["span_id"] = span.span_id
try:
worker_run_data = await self._get_model(params)
except Exception as e:
yield ModelOutput(
text=f"**LLMServer Generate Error, Please CheckErrorInfo.**: {e}",
error_code=0,
)
return
async with worker_run_data.semaphore:
if worker_run_data.worker.support_async():
async for outout in worker_run_data.worker.async_generate_stream(
params
):
yield outout
else:
if not async_wrapper:
from starlette.concurrency import iterate_in_threadpool
async_wrapper = iterate_in_threadpool
async for output in async_wrapper(
worker_run_data.worker.generate_stream(params)
):
yield output
async_wrapper = iterate_in_threadpool
async for output in async_wrapper(
worker_run_data.worker.generate_stream(params)
):
yield output
async def generate(self, params: Dict) -> ModelOutput:
"""Generate non stream result"""
try:
worker_run_data = await self._get_model(params)
except Exception as e:
return ModelOutput(
text=f"**LLMServer Generate Error, Please CheckErrorInfo.**: {e}",
error_code=0,
)
async with worker_run_data.semaphore:
if worker_run_data.worker.support_async():
return await worker_run_data.worker.async_generate(params)
else:
return await self.run_blocking_func(
worker_run_data.worker.generate, params
with root_tracer.start_span(
"WorkerManager.generate", params.get("span_id")
) as span:
params["span_id"] = span.span_id
try:
worker_run_data = await self._get_model(params)
except Exception as e:
return ModelOutput(
text=f"**LLMServer Generate Error, Please CheckErrorInfo.**: {e}",
error_code=0,
)
async with worker_run_data.semaphore:
if worker_run_data.worker.support_async():
return await worker_run_data.worker.async_generate(params)
else:
return await self.run_blocking_func(
worker_run_data.worker.generate, params
)
async def embeddings(self, params: Dict) -> List[List[float]]:
"""Embed input"""
try:
worker_run_data = await self._get_model(params, worker_type="text2vec")
except Exception as e:
raise e
async with worker_run_data.semaphore:
if worker_run_data.worker.support_async():
return await worker_run_data.worker.async_embeddings(params)
else:
return await self.run_blocking_func(
worker_run_data.worker.embeddings, params
)
with root_tracer.start_span(
"WorkerManager.embeddings", params.get("span_id")
) as span:
params["span_id"] = span.span_id
try:
worker_run_data = await self._get_model(params, worker_type="text2vec")
except Exception as e:
raise e
async with worker_run_data.semaphore:
if worker_run_data.worker.support_async():
return await worker_run_data.worker.async_embeddings(params)
else:
return await self.run_blocking_func(
worker_run_data.worker.embeddings, params
)
def sync_embeddings(self, params: Dict) -> List[List[float]]:
worker_run_data = self._sync_get_model(params, worker_type="text2vec")
@ -608,6 +623,9 @@ async def generate_json_stream(params):
@router.post("/worker/generate_stream")
async def api_generate_stream(request: PromptRequest):
params = request.dict(exclude_none=True)
span_id = root_tracer.get_current_span_id()
if "span_id" not in params and span_id:
params["span_id"] = span_id
generator = generate_json_stream(params)
return StreamingResponse(generator)
@ -615,12 +633,18 @@ async def api_generate_stream(request: PromptRequest):
@router.post("/worker/generate")
async def api_generate(request: PromptRequest):
params = request.dict(exclude_none=True)
span_id = root_tracer.get_current_span_id()
if "span_id" not in params and span_id:
params["span_id"] = span_id
return await worker_manager.generate(params)
@router.post("/worker/embeddings")
async def api_embeddings(request: EmbeddingsRequest):
params = request.dict(exclude_none=True)
span_id = root_tracer.get_current_span_id()
if "span_id" not in params and span_id:
params["span_id"] = span_id
return await worker_manager.embeddings(params)
@ -801,10 +825,18 @@ def _build_worker(worker_params: ModelWorkerParameters):
def _start_local_worker(
worker_manager: WorkerManagerAdapter, worker_params: ModelWorkerParameters
):
worker = _build_worker(worker_params)
if not worker_manager.worker_manager:
worker_manager.worker_manager = _create_local_model_manager(worker_params)
worker_manager.worker_manager.add_worker(worker, worker_params)
with root_tracer.start_span(
"WorkerManager._start_local_worker",
span_type=SpanType.RUN,
metadata={
"run_service": SpanTypeRunName.WORKER_MANAGER,
"params": _get_dict_from_obj(worker_params),
},
):
worker = _build_worker(worker_params)
if not worker_manager.worker_manager:
worker_manager.worker_manager = _create_local_model_manager(worker_params)
worker_manager.worker_manager.add_worker(worker, worker_params)
def _start_local_embedding_worker(
@ -928,17 +960,17 @@ def run_worker_manager(
# Run worker manager independently
embedded_mod = False
app = _setup_fastapi(worker_params)
_start_local_worker(worker_manager, worker_params)
_start_local_embedding_worker(
worker_manager, embedding_model_name, embedding_model_path
)
else:
_start_local_worker(worker_manager, worker_params)
_start_local_embedding_worker(
worker_manager, embedding_model_name, embedding_model_path
)
loop = asyncio.get_event_loop()
loop.run_until_complete(worker_manager.start())
system_app = SystemApp(app)
initialize_tracer(
system_app,
os.path.join(LOGDIR, "dbgpt_model_worker_manager_tracer.jsonl"),
root_operation_name="DB-GPT-WorkerManager-Entry",
)
_start_local_worker(worker_manager, worker_params)
_start_local_embedding_worker(
worker_manager, embedding_model_name, embedding_model_path
)
if include_router:
app.include_router(router, prefix="/api")
@ -946,6 +978,8 @@ def run_worker_manager(
if not embedded_mod:
import uvicorn
loop = asyncio.get_event_loop()
loop.run_until_complete(worker_manager.start())
uvicorn.run(
app, host=worker_params.host, port=worker_params.port, log_level="info"
)

View File

@ -46,7 +46,7 @@ from pilot.summary.db_summary_client import DBSummaryClient
from pilot.model.cluster import BaseModelController, WorkerManager, WorkerManagerFactory
from pilot.model.base import FlatSupportedModel
from pilot.utils.tracer import root_tracer
from pilot.utils.tracer import root_tracer, SpanType
router = APIRouter()
CFG = Config()
@ -367,7 +367,9 @@ async def chat_completions(dialogue: ConversationVo = Body()):
print(
f"chat_completions:{dialogue.chat_mode},{dialogue.select_param},{dialogue.model_name}"
)
with root_tracer.start_span("chat_completions", metadata=dialogue.dict()) as _:
with root_tracer.start_span(
"get_chat_instance", span_type=SpanType.CHAT, metadata=dialogue.dict()
):
chat: BaseChat = get_chat_instance(dialogue)
# background_tasks = BackgroundTasks()
# background_tasks.add_task(release_model_semaphore)
@ -419,11 +421,10 @@ async def model_supports(worker_manager: WorkerManager = Depends(get_worker_mana
async def no_stream_generator(chat):
span = root_tracer.start_span("no_stream_generator")
msg = await chat.nostream_call()
msg = msg.replace("\n", "\\n")
yield f"data: {msg}\n\n"
span.end()
with root_tracer.start_span("no_stream_generator"):
msg = await chat.nostream_call()
msg = msg.replace("\n", "\\n")
yield f"data: {msg}\n\n"
async def stream_generator(chat, incremental: bool, model_name: str):
@ -467,9 +468,9 @@ async def stream_generator(chat, incremental: bool, model_name: str):
yield f"data:{msg}\n\n"
previous_response = msg
await asyncio.sleep(0.02)
span.end()
if incremental:
yield "data: [DONE]\n\n"
span.end()
chat.current_message.add_ai_message(msg)
chat.current_message.add_view_message(msg)
chat.memory.append(chat.current_message)

View File

@ -139,7 +139,9 @@ class BaseChat(ABC):
def _get_span_metadata(self, payload: Dict) -> Dict:
metadata = {k: v for k, v in payload.items()}
del metadata["prompt"]
metadata["messages"] = list(map(lambda m: m.dict(), metadata["messages"]))
metadata["messages"] = list(
map(lambda m: m if isinstance(m, dict) else m.dict(), metadata["messages"])
)
return metadata
async def stream_call(self):
@ -152,6 +154,7 @@ class BaseChat(ABC):
span = root_tracer.start_span(
"BaseChat.stream_call", metadata=self._get_span_metadata(payload)
)
payload["span_id"] = span.span_id
try:
from pilot.model.cluster import WorkerManagerFactory
@ -178,6 +181,7 @@ class BaseChat(ABC):
span = root_tracer.start_span(
"BaseChat.nostream_call", metadata=self._get_span_metadata(payload)
)
payload["span_id"] = span.span_id
try:
from pilot.model.cluster import WorkerManagerFactory
@ -185,7 +189,7 @@ class BaseChat(ABC):
ComponentType.WORKER_MANAGER_FACTORY, WorkerManagerFactory
).create()
with root_tracer.start_span("BaseChat.invoke_worker_manager.generate") as _:
with root_tracer.start_span("BaseChat.invoke_worker_manager.generate"):
model_output = await worker_manager.generate(payload)
### output parse
@ -206,7 +210,7 @@ class BaseChat(ABC):
"ai_response_text": ai_response_text,
"prompt_define_response": prompt_define_response,
}
with root_tracer.start_span("BaseChat.do_action", metadata=metadata) as _:
with root_tracer.start_span("BaseChat.do_action", metadata=metadata):
### run
result = self.do_action(prompt_define_response)

View File

@ -119,6 +119,14 @@ except ImportError as e:
logging.warning(f"Integrating dbgpt knowledge command line tool failed: {e}")
try:
from pilot.utils.tracer.tracer_cli import trace_cli_group
add_command_alias(trace_cli_group, name="trace", parent_group=cli)
except ImportError as e:
logging.warning(f"Integrating dbgpt trace command line tool failed: {e}")
def main():
return cli()

View File

@ -8,8 +8,6 @@ from pilot.component import ComponentType, SystemApp
from pilot.utils.executor_utils import DefaultExecutorFactory
from pilot.embedding_engine.embedding_factory import EmbeddingFactory
from pilot.server.base import WebWerverParameters
from pilot.configs.model_config import LOGDIR
from pilot.utils.tracer import root_tracer, initialize_tracer
if TYPE_CHECKING:
from langchain.embeddings.base import Embeddings
@ -26,8 +24,6 @@ def initialize_components(
):
from pilot.model.cluster.controller.controller import controller
initialize_tracer(system_app, os.path.join(LOGDIR, "dbgpt_webserver_tracer.jsonl"))
# Register global default executor factory first
system_app.register(DefaultExecutorFactory)

View File

@ -7,7 +7,7 @@ ROOT_PATH = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__fi
sys.path.append(ROOT_PATH)
import signal
from pilot.configs.config import Config
from pilot.configs.model_config import LLM_MODEL_CONFIG, EMBEDDING_MODEL_CONFIG
from pilot.configs.model_config import LLM_MODEL_CONFIG, EMBEDDING_MODEL_CONFIG, LOGDIR
from pilot.component import SystemApp
from pilot.server.base import (
@ -38,6 +38,8 @@ from pilot.utils.utils import (
_get_logging_level,
logging_str_to_uvicorn_level,
)
from pilot.utils.tracer import root_tracer, initialize_tracer, SpanType, SpanTypeRunName
from pilot.utils.parameter_utils import _get_dict_from_obj
static_file_path = os.path.join(os.getcwd(), "server/static")
@ -98,17 +100,21 @@ def mount_static_files(app):
app.add_exception_handler(RequestValidationError, validation_exception_handler)
def _get_webserver_params(args: List[str] = None):
from pilot.utils.parameter_utils import EnvArgumentParser
parser: argparse.ArgumentParser = EnvArgumentParser.create_argparse_option(
WebWerverParameters
)
return WebWerverParameters(**vars(parser.parse_args(args=args)))
def initialize_app(param: WebWerverParameters = None, args: List[str] = None):
"""Initialize app
If you use gunicorn as a process manager, initialize_app can be invoke in `on_starting` hook.
"""
if not param:
from pilot.utils.parameter_utils import EnvArgumentParser
parser: argparse.ArgumentParser = EnvArgumentParser.create_argparse_option(
WebWerverParameters
)
param = WebWerverParameters(**vars(parser.parse_args(args=args)))
param = _get_webserver_params(args)
if not param.log_level:
param.log_level = _get_logging_level()
@ -127,7 +133,7 @@ def initialize_app(param: WebWerverParameters = None, args: List[str] = None):
model_start_listener = _create_model_start_listener(system_app)
initialize_components(param, system_app, embedding_model_name, embedding_model_path)
model_path = LLM_MODEL_CONFIG[CFG.LLM_MODEL]
model_path = LLM_MODEL_CONFIG.get(CFG.LLM_MODEL)
if not param.light:
print("Model Unified Deployment Mode!")
if not param.remote_embedding:
@ -174,8 +180,20 @@ def run_uvicorn(param: WebWerverParameters):
def run_webserver(param: WebWerverParameters = None):
param = initialize_app(param)
run_uvicorn(param)
if not param:
param = _get_webserver_params()
initialize_tracer(system_app, os.path.join(LOGDIR, "dbgpt_webserver_tracer.jsonl"))
with root_tracer.start_span(
"run_webserver",
span_type=SpanType.RUN,
metadata={
"run_service": SpanTypeRunName.WEBSERVER,
"params": _get_dict_from_obj(param),
},
):
param = initialize_app(param)
run_uvicorn(param)
if __name__ == "__main__":

View File

@ -13,7 +13,7 @@ from pilot.model.cluster import run_worker_manager
CFG = Config()
model_path = LLM_MODEL_CONFIG[CFG.LLM_MODEL]
model_path = LLM_MODEL_CONFIG.get(CFG.LLM_MODEL)
if __name__ == "__main__":
run_worker_manager(

View File

@ -1,6 +1,6 @@
import argparse
import os
from dataclasses import dataclass, fields, MISSING, asdict, field
from dataclasses import dataclass, fields, MISSING, asdict, field, is_dataclass
from typing import Any, List, Optional, Type, Union, Callable, Dict
from collections import OrderedDict
@ -590,6 +590,20 @@ def _extract_parameter_details(
return descriptions
def _get_dict_from_obj(obj, default_value=None) -> Optional[Dict]:
if not obj:
return None
if is_dataclass(type(obj)):
params = {}
for field_info in fields(obj):
value = _get_simple_privacy_field_value(obj, field_info)
params[field_info.name] = value
return params
if isinstance(obj, dict):
return obj
return default_value
class _SimpleArgParser:
def __init__(self, *args):
self.params = {arg.replace("_", "-"): None for arg in args}

View File

@ -1,5 +1,7 @@
from pilot.utils.tracer.base import (
SpanType,
Span,
SpanTypeRunName,
Tracer,
SpanStorage,
SpanStorageType,
@ -14,7 +16,9 @@ from pilot.utils.tracer.tracer_impl import (
)
__all__ = [
"SpanType",
"Span",
"SpanTypeRunName",
"Tracer",
"SpanStorage",
"SpanStorageType",

View File

@ -10,22 +10,41 @@ from datetime import datetime
from pilot.component import BaseComponent, SystemApp, ComponentType
class SpanType(str, Enum):
    """Category of a traced span.

    Subclassing ``str`` makes members compare equal to their plain string
    values, so serialized spans can be matched with either form.
    """

    # Generic span with no special semantics.
    BASE = "base"
    # Marks a service startup/run span (see SpanTypeRunName for services).
    RUN = "run"
    # Marks the root span of a chat conversation.
    CHAT = "chat"
class SpanTypeRunName(str, Enum):
    """Well-known service names recorded on ``SpanType.RUN`` spans.

    Stored in span metadata under the ``run_service`` key; the str base
    class lets members compare equal to their serialized string values.
    """

    WEBSERVER = "Webserver"
    WORKER_MANAGER = "WorkerManager"
    MODEL_WORKER = "ModelWorker"
    EMBEDDING_MODEL = "EmbeddingModel"

    @staticmethod
    def values():
        """Return every member's string value, in declaration order."""
        return [member.value for member in SpanTypeRunName]
class Span:
"""Represents a unit of work that is being traced.
This can be any operation like a function call or a database query.
"""
span_type: str = "base"
def __init__(
self,
trace_id: str,
span_id: str,
span_type: SpanType = None,
parent_span_id: str = None,
operation_name: str = None,
metadata: Dict = None,
end_caller: Callable[[Span], None] = None,
):
if not span_type:
span_type = SpanType.BASE
self.span_type = span_type
# The unique identifier for the entire trace
self.trace_id = trace_id
# Unique identifier for this span within the trace
@ -65,7 +84,7 @@ class Span:
def to_dict(self) -> Dict:
return {
"span_type": self.span_type,
"span_type": self.span_type.value,
"trace_id": self.trace_id,
"span_id": self.span_id,
"parent_span_id": self.parent_span_id,
@ -124,7 +143,11 @@ class Tracer(BaseComponent, ABC):
@abstractmethod
def start_span(
self, operation_name: str, parent_span_id: str = None, metadata: Dict = None
self,
operation_name: str,
parent_span_id: str = None,
span_type: SpanType = None,
metadata: Dict = None,
) -> Span:
"""Begin a new span for the given operation. If provided, the span will be
a child of the span with the given parent_span_id.
@ -158,4 +181,4 @@ class Tracer(BaseComponent, ABC):
@dataclass
class TracerContext:
span_id: str
span_id: Optional[str] = None

View File

@ -0,0 +1,540 @@
import os
import click
import logging
import glob
import json
from datetime import datetime
from typing import Iterable, Dict, Callable
from pilot.configs.model_config import LOGDIR
from pilot.utils.tracer import SpanType, SpanTypeRunName
logger = logging.getLogger("dbgpt_cli")
_DEFAULT_FILE_PATTERN = os.path.join(LOGDIR, "dbgpt*.jsonl")
@click.group("trace")
def trace_cli_group():
    """Analyze and visualize trace spans."""
@trace_cli_group.command()
@click.option(
    "--trace_id",
    required=False,
    type=str,
    default=None,
    show_default=True,
    help="Specify the trace ID to list",
)
@click.option(
    "--span_id",
    required=False,
    type=str,
    default=None,
    show_default=True,
    help="Specify the Span ID to list.",
)
@click.option(
    "--span_type",
    required=False,
    type=str,
    default=None,
    show_default=True,
    help="Specify the Span Type to list.",
)
@click.option(
    "--parent_span_id",
    required=False,
    type=str,
    default=None,
    show_default=True,
    help="Specify the Parent Span ID to list.",
)
@click.option(
    "--search",
    required=False,
    type=str,
    default=None,
    show_default=True,
    help="Search trace_id, span_id, parent_span_id, operation_name or content in metadata.",
)
@click.option(
    "-l",
    "--limit",
    type=int,
    default=20,
    help="Limit the number of recent span displayed.",
)
@click.option(
    "--start_time",
    type=str,
    help='Filter by start time. Format: "YYYY-MM-DD HH:MM:SS.mmm"',
)
@click.option(
    "--end_time", type=str, help='Filter by end time. Format: "YYYY-MM-DD HH:MM:SS.mmm"'
)
@click.option(
    "--desc",
    required=False,
    type=bool,
    default=False,
    is_flag=True,
    help="Whether to use reverse sorting. By default, sorting is based on start time.",
)
@click.option(
    "--output",
    required=False,
    type=click.Choice(["text", "html", "csv", "latex", "json"]),
    default="text",
    help="The output format",
)
@click.argument("files", nargs=-1, type=click.Path(exists=True, readable=True))
def list(
    trace_id: str,
    span_id: str,
    span_type: str,
    parent_span_id: str,
    search: str,
    limit: int,
    start_time: str,
    end_time: str,
    desc: bool,
    output: str,
    files=None,
):
    """List your trace spans"""
    from prettytable import PrettyTable

    # If no files are explicitly specified, read_spans_from_files falls back
    # to the default log file pattern.
    spans = read_spans_from_files(files)

    # Apply exact-match filters lazily; each wraps the previous iterator.
    if trace_id:
        spans = filter(lambda s: s["trace_id"] == trace_id, spans)
    if span_id:
        spans = filter(lambda s: s["span_id"] == span_id, spans)
    if span_type:
        spans = filter(lambda s: s["span_type"] == span_type, spans)
    if parent_span_id:
        spans = filter(lambda s: s["parent_span_id"] == parent_span_id, spans)
    # Filter spans based on the start and end times (both compare against the
    # span's start_time, so --end_time bounds when a span *began*).
    if start_time:
        start_dt = _parse_datetime(start_time)
        spans = filter(
            lambda span: _parse_datetime(span["start_time"]) >= start_dt, spans
        )
    if end_time:
        end_dt = _parse_datetime(end_time)
        spans = filter(
            lambda span: _parse_datetime(span["start_time"]) <= end_dt, spans
        )
    if search:
        spans = filter(_new_search_span_func(search), spans)

    # Sort spans based on the start time, then keep only the first `limit`.
    spans = sorted(
        spans, key=lambda span: _parse_datetime(span["start_time"]), reverse=desc
    )[:limit]

    table = PrettyTable(
        ["Trace ID", "Span ID", "Operation Name", "Conversation UID"],
    )
    for sp in spans:
        conv_uid = None
        # Fix: the original guard was `if "metadata" in sp and sp:` — the
        # trailing `and sp` was a no-op (sp is always truthy here); guard the
        # metadata value itself instead.
        metadata = sp.get("metadata")
        if isinstance(metadata, dict):
            conv_uid = metadata.get("conv_uid")
        table.add_row(
            [
                sp.get("trace_id"),
                sp.get("span_id"),
                sp.get("operation_name"),
                conv_uid,
            ]
        )
    # json output needs ensure_ascii=False to keep non-ASCII text readable.
    out_kwargs = {"ensure_ascii": False} if output == "json" else {}
    print(table.get_formatted_string(out_format=output, **out_kwargs))
@trace_cli_group.command()
@click.option(
    "--trace_id",
    required=True,
    type=str,
    help="Specify the trace ID to list",
)
@click.argument("files", nargs=-1, type=click.Path(exists=True, readable=True))
def tree(trace_id: str, files):
    """Display trace links as a tree"""
    _print_trace_hierarchy(_view_trace_hierarchy(trace_id, files))
@trace_cli_group.command()
@click.option(
    "--trace_id",
    required=False,
    type=str,
    default=None,
    help="Specify the trace ID to analyze. If None, show latest conversation details",
)
@click.option(
    "--tree",
    required=False,
    type=bool,
    default=False,
    is_flag=True,
    help="Display trace spans as a tree",
)
@click.option(
    "--hide_run_params",
    required=False,
    type=bool,
    default=False,
    is_flag=True,
    help="Hide run params",
)
@click.option(
    "--output",
    required=False,
    type=click.Choice(["text", "html", "csv", "latex", "json"]),
    default="text",
    help="The output format",
)
@click.argument("files", nargs=-1, type=click.Path(exists=False, readable=True))
def chat(trace_id: str, tree: bool, hide_run_params: bool, output: str, files):
    """Show conversation details"""
    from prettytable import PrettyTable

    spans = read_spans_from_files(files)
    # Sort by start time, newest first, so the latest conversation and the
    # most recent RUN span of each service are encountered first.
    spans = sorted(
        spans, key=lambda span: _parse_datetime(span["start_time"]), reverse=True
    )
    service_spans = {}
    service_names = set(SpanTypeRunName.values())
    found_trace_id = None
    for sp in spans:
        span_type = sp["span_type"]
        metadata = sp.get("metadata")
        if span_type == SpanType.RUN:
            # RUN spans carry the service name and its startup params
            # in metadata (see the span producers in the tracer setup).
            service_name = metadata["run_service"]
            service_spans[service_name] = sp.copy()
            # Stop scanning once every service has a span and a
            # conversation has been located.
            if set(service_spans.keys()) == service_names and found_trace_id:
                break
        elif span_type == SpanType.CHAT and not found_trace_id:
            if not trace_id:
                # No trace ID requested: pick the most recent conversation.
                found_trace_id = sp["trace_id"]
            if trace_id and trace_id == sp["trace_id"]:
                found_trace_id = trace_id

    service_tables = {}
    # json output needs ensure_ascii=False to keep non-ASCII text readable.
    out_kwargs = {"ensure_ascii": False} if output == "json" else {}
    for service_name, sp in service_spans.items():
        metadata = sp["metadata"]
        table = PrettyTable(["Config Key", "Config Value"], title=service_name)
        for k, v in metadata["params"].items():
            table.add_row([k, v])
        service_tables[service_name] = table
    if not hide_run_params:
        # service_spans is keyed by the plain string from
        # metadata["run_service"], so look up with the enum *values*
        # consistently (the original mixed enum members and .value here,
        # which only worked because SpanTypeRunName subclasses str).
        merged_table1 = merge_tables_horizontally(
            [
                service_tables.get(SpanTypeRunName.WEBSERVER.value),
                service_tables.get(SpanTypeRunName.EMBEDDING_MODEL.value),
            ]
        )
        merged_table2 = merge_tables_horizontally(
            [
                service_tables.get(SpanTypeRunName.MODEL_WORKER.value),
                service_tables.get(SpanTypeRunName.WORKER_MANAGER.value),
            ]
        )
        if output == "text":
            print(merged_table1)
            print(merged_table2)
        else:
            # Non-text formats cannot be merged horizontally; emit each
            # service table separately.
            for service_name, table in service_tables.items():
                print(table.get_formatted_string(out_format=output, **out_kwargs))

    if not found_trace_id:
        print(f"Can't find conversation with trace_id: {trace_id}")
        return
    trace_id = found_trace_id
    # Restore chronological order (spans were sorted newest-first above).
    trace_spans = [span for span in spans if span["trace_id"] == trace_id]
    trace_spans = [s for s in reversed(trace_spans)]
    hierarchy = _build_trace_hierarchy(trace_spans)
    if tree:
        print("\nInvoke Trace Tree:\n")
        _print_trace_hierarchy(hierarchy)

    trace_spans = _get_ordered_trace_from(hierarchy)
    table = PrettyTable(["Key", "Value"], title="Chat Trace Details")
    # Only wrap long values for plain-text output; structured formats keep
    # the raw strings.
    split_long_text = output == "text"
    for sp in trace_spans:
        op = sp["operation_name"]
        metadata = sp.get("metadata")
        # "Start" records have end_time is None; their metadata holds the
        # request parameters. "End" records may carry error details.
        if op == "get_chat_instance" and not sp["end_time"]:
            table.add_row(["trace_id", trace_id])
            table.add_row(["span_id", sp["span_id"]])
            table.add_row(["conv_uid", metadata.get("conv_uid")])
            table.add_row(["user_input", metadata.get("user_input")])
            table.add_row(["chat_mode", metadata.get("chat_mode")])
            table.add_row(["select_param", metadata.get("select_param")])
            table.add_row(["model_name", metadata.get("model_name")])
        if op in ["BaseChat.stream_call", "BaseChat.nostream_call"]:
            if not sp["end_time"]:
                table.add_row(["temperature", metadata.get("temperature")])
                table.add_row(["max_new_tokens", metadata.get("max_new_tokens")])
                table.add_row(["echo", metadata.get("echo")])
            elif "error" in metadata:
                table.add_row(["BaseChat Error", metadata.get("error")])
        if op == "BaseChat.nostream_call" and not sp["end_time"]:
            if "model_output" in metadata:
                table.add_row(
                    [
                        "BaseChat model_output",
                        split_string_by_terminal_width(
                            metadata.get("model_output").get("text"),
                            split=split_long_text,
                        ),
                    ]
                )
            if "ai_response_text" in metadata:
                table.add_row(
                    [
                        "BaseChat ai_response_text",
                        split_string_by_terminal_width(
                            metadata.get("ai_response_text"), split=split_long_text
                        ),
                    ]
                )
            if "prompt_define_response" in metadata:
                table.add_row(
                    [
                        "BaseChat prompt_define_response",
                        split_string_by_terminal_width(
                            metadata.get("prompt_define_response"),
                            split=split_long_text,
                        ),
                    ]
                )
        if op == "DefaultModelWorker_call.generate_stream_func":
            if not sp["end_time"]:
                table.add_row(["llm_adapter", metadata.get("llm_adapter")])
                table.add_row(
                    [
                        "User prompt",
                        split_string_by_terminal_width(
                            metadata.get("prompt"), split=split_long_text
                        ),
                    ]
                )
            else:
                table.add_row(
                    [
                        "Model output",
                        split_string_by_terminal_width(metadata.get("output")),
                    ]
                )
        if (
            op
            in [
                "DefaultModelWorker.async_generate_stream",
                "DefaultModelWorker.generate_stream",
            ]
            and metadata
            and "error" in metadata
        ):
            table.add_row(["Model Error", metadata.get("error")])

    print(table.get_formatted_string(out_format=output, **out_kwargs))
def read_spans_from_files(files=None) -> Iterable[Dict]:
    """
    Reads spans from multiple files based on the provided file paths.

    Each path may be a glob pattern; every matched file is expected to be
    JSON Lines (one JSON-encoded span per line). Blank lines are skipped so
    a trailing newline or accidental empty line does not raise a
    JSONDecodeError, which the previous implementation did.

    :param files: iterable of file paths or glob patterns; when falsy, the
        default log pattern is used.
    :return: lazy iterator of span dicts.
    """
    if not files:
        files = [_DEFAULT_FILE_PATTERN]
    for filepath in files:
        for filename in glob.glob(filepath):
            with open(filename, "r", encoding="utf-8") as fd:
                for line in fd:
                    line = line.strip()
                    if line:
                        yield json.loads(line)
def _new_search_span_func(search: str):
def func(span: Dict) -> bool:
items = [span["trace_id"], span["span_id"], span["parent_span_id"]]
if "operation_name" in span:
items.append(span["operation_name"])
if "metadata" in span:
metadata = span["metadata"]
if isinstance(metadata, dict):
for k, v in metadata.items():
items.append(k)
items.append(v)
return any(search in str(item) for item in items if item)
return func
def _parse_datetime(dt_str):
"""Parse a datetime string to a datetime object."""
return datetime.strptime(dt_str, "%Y-%m-%d %H:%M:%S.%f")
def _build_trace_hierarchy(spans, parent_span_id=None, indent=0):
# Current spans
current_level_spans = [
span
for span in spans
if span["parent_span_id"] == parent_span_id and span["end_time"] is None
]
hierarchy = []
for start_span in current_level_spans:
# Find end span
end_span = next(
(
span
for span in spans
if span["span_id"] == start_span["span_id"]
and span["end_time"] is not None
),
None,
)
entry = {
"operation_name": start_span["operation_name"],
"parent_span_id": start_span["parent_span_id"],
"span_id": start_span["span_id"],
"start_time": start_span["start_time"],
"end_time": start_span["end_time"],
"metadata": start_span["metadata"],
"children": _build_trace_hierarchy(
spans, start_span["span_id"], indent + 1
),
}
hierarchy.append(entry)
# Append end span
if end_span:
entry_end = {
"operation_name": end_span["operation_name"],
"parent_span_id": end_span["parent_span_id"],
"span_id": end_span["span_id"],
"start_time": end_span["start_time"],
"end_time": end_span["end_time"],
"metadata": end_span["metadata"],
"children": [],
}
hierarchy.append(entry_end)
return hierarchy
def _view_trace_hierarchy(trace_id, files=None):
    """Load spans for *trace_id* and return its nested call hierarchy.

    Args:
        trace_id: trace to select spans for.
        files: optional glob patterns forwarded to
            ``read_spans_from_files``.

    Returns:
        The nested hierarchy produced by ``_build_trace_hierarchy``.
    """
    matching = [
        span for span in read_spans_from_files(files) if span["trace_id"] == trace_id
    ]
    return _build_trace_hierarchy(matching)
def _print_trace_hierarchy(hierarchy, indent=0):
"""Print link hierarchy"""
for entry in hierarchy:
print(
" " * indent
+ f"Operation: {entry['operation_name']} (Start: {entry['start_time']}, End: {entry['end_time']})"
)
_print_trace_hierarchy(entry["children"], indent + 1)
def _get_ordered_trace_from(hierarchy):
traces = []
def func(items):
for item in items:
traces.append(item)
func(item["children"])
func(hierarchy)
return traces
def _print(service_spans: Dict):
    """Print spans grouped by service run-type (unfinished stub).

    NOTE(review): the loop body is ``pass``, so this currently does
    nothing with *service_spans*.  The group lists also mix ``.name``
    strings with raw ``SpanTypeRunName`` enum members — presumably both
    were meant to be ``.name``; confirm intent before completing this.
    """
    for names in [
        [SpanTypeRunName.WEBSERVER.name, SpanTypeRunName.EMBEDDING_MODEL],
        [SpanTypeRunName.WORKER_MANAGER.name, SpanTypeRunName.MODEL_WORKER],
    ]:
        pass
def merge_tables_horizontally(tables):
    """Merge several PrettyTable instances side by side into one table.

    Column headers are suffixed with their source table's title (when it
    has one); shorter tables are padded with empty cells so every merged
    row supplies a value for every column.

    Args:
        tables: iterable of PrettyTable objects; falsy entries ignored.

    Returns:
        A new merged PrettyTable, or None when no usable table is given.
    """
    from prettytable import PrettyTable

    if not tables:
        return None
    usable = [t for t in tables if t]
    if not usable:
        return None

    merged = PrettyTable()
    headers = []
    for table in usable:
        title = table.title
        for name in table.field_names:
            headers.append(f"{name} ({title})" if title else f"{name}")
    merged.field_names = headers

    row_count = max(len(table._rows) for table in usable)
    for idx in range(row_count):
        combined = []
        for table in usable:
            if idx < len(table._rows):
                combined.extend(table._rows[idx])
            else:
                # Pad shorter tables so column counts stay aligned.
                combined.extend([""] * len(table.field_names))
        merged.add_row(combined)
    return merged
def split_string_by_terminal_width(s, split=True, max_len=None, sp="\n"):
    """Split a string into chunks sized to the current terminal width.

    Args:
        s: the input string; returned unchanged when falsy, so callers
            may safely pass None (e.g. a missing ``metadata`` field).
        split: when False, return *s* untouched.
        max_len: chunk length; defaults to 80% of the terminal width.
        sp: separator joined between chunks.

    Returns:
        The chunked string, or *s* unchanged when splitting is disabled
        or *s* is empty/None.
    """
    # Guard against None/empty input: callers pass metadata.get(...)
    # results that may be absent, which previously raised TypeError.
    if not split or not s:
        return s
    if not max_len:
        try:
            max_len = int(os.get_terminal_size().columns * 0.8)
        except OSError:
            # Fall back to a fixed width when no terminal is attached.
            max_len = 100
    return sp.join(s[i : i + max_len] for i in range(0, len(s), max_len))

View File

@ -4,6 +4,7 @@ from functools import wraps
from pilot.component import SystemApp, ComponentType
from pilot.utils.tracer.base import (
SpanType,
Span,
Tracer,
SpanStorage,
@ -32,14 +33,23 @@ class DefaultTracer(Tracer):
self._get_current_storage().append_span(span)
def start_span(
self, operation_name: str, parent_span_id: str = None, metadata: Dict = None
self,
operation_name: str,
parent_span_id: str = None,
span_type: SpanType = None,
metadata: Dict = None,
) -> Span:
trace_id = (
self._new_uuid() if parent_span_id is None else parent_span_id.split(":")[0]
)
span_id = f"{trace_id}:{self._new_uuid()}"
span = Span(
trace_id, span_id, parent_span_id, operation_name, metadata=metadata
trace_id,
span_id,
span_type,
parent_span_id,
operation_name,
metadata=metadata,
)
if self._span_storage_type in [
@ -81,11 +91,13 @@ class DefaultTracer(Tracer):
class TracerManager:
"""The manager of current tracer"""
def __init__(self) -> None:
self._system_app: Optional[SystemApp] = None
self._trace_context_var: ContextVar[TracerContext] = ContextVar(
"trace_context",
default=TracerContext(span_id="default_trace_id:default_span_id"),
default=TracerContext(),
)
def initialize(
@ -101,18 +113,23 @@ class TracerManager:
return self._system_app.get_component(ComponentType.TRACER, Tracer, None)
def start_span(
self, operation_name: str, parent_span_id: str = None, metadata: Dict = None
self,
operation_name: str,
parent_span_id: str = None,
span_type: SpanType = None,
metadata: Dict = None,
) -> Span:
"""Start a new span with operation_name
This method must not throw an exception under any case and try not to block as much as possible
"""
tracer = self._get_tracer()
if not tracer:
return Span("empty_span", "empty_span")
if not parent_span_id:
current_span = self.get_current_span()
if current_span:
parent_span_id = current_span.span_id
else:
parent_span_id = self._trace_context_var.get().span_id
return tracer.start_span(operation_name, parent_span_id, metadata)
parent_span_id = self.get_current_span_id()
return tracer.start_span(
operation_name, parent_span_id, span_type=span_type, metadata=metadata
)
def end_span(self, span: Span, **kwargs):
tracer = self._get_tracer()
@ -126,15 +143,22 @@ class TracerManager:
return None
return tracer.get_current_span()
def get_current_span_id(self) -> Optional[str]:
current_span = self.get_current_span()
if current_span:
return current_span.span_id
ctx = self._trace_context_var.get()
return ctx.span_id if ctx else None
root_tracer: TracerManager = TracerManager()
def trace(operation_name: str):
def trace(operation_name: str, **trace_kwargs):
def decorator(func):
@wraps(func)
async def wrapper(*args, **kwargs):
with root_tracer.start_span(operation_name) as span:
with root_tracer.start_span(operation_name, **trace_kwargs):
return await func(*args, **kwargs)
return wrapper
@ -142,14 +166,18 @@ def trace(operation_name: str):
return decorator
def initialize_tracer(system_app: SystemApp, tracer_filename: str):
def initialize_tracer(
system_app: SystemApp,
tracer_filename: str,
root_operation_name: str = "DB-GPT-Web-Entry",
):
if not system_app:
return
from pilot.utils.tracer.span_storage import FileSpanStorage
trace_context_var = ContextVar(
"trace_context",
default=TracerContext(span_id="default_trace_id:default_span_id"),
default=TracerContext(),
)
tracer = DefaultTracer(system_app)
@ -160,5 +188,8 @@ def initialize_tracer(system_app: SystemApp, tracer_filename: str):
from pilot.utils.tracer.tracer_middleware import TraceIDMiddleware
system_app.app.add_middleware(
TraceIDMiddleware, trace_context_var=trace_context_var, tracer=tracer
TraceIDMiddleware,
trace_context_var=trace_context_var,
tracer=tracer,
root_operation_name=root_operation_name,
)

View File

@ -16,12 +16,14 @@ class TraceIDMiddleware(BaseHTTPMiddleware):
app: ASGIApp,
trace_context_var: ContextVar[TracerContext],
tracer: Tracer,
root_operation_name: str = "DB-GPT-Web-Entry",
include_prefix: str = "/api",
exclude_paths=_DEFAULT_EXCLUDE_PATHS,
):
super().__init__(app)
self.trace_context_var = trace_context_var
self.tracer = tracer
self.root_operation_name = root_operation_name
self.include_prefix = include_prefix
self.exclude_paths = exclude_paths
@ -37,7 +39,7 @@ class TraceIDMiddleware(BaseHTTPMiddleware):
# self.trace_context_var.set(TracerContext(span_id=span_id))
with self.tracer.start_span(
"DB-GPT-Web-Entry", span_id, metadata={"path": request.url.path}
) as _:
self.root_operation_name, span_id, metadata={"path": request.url.path}
):
response = await call_next(request)
return response