From 2a46909eac5bdaf87ac17a7fe84bcf65b73949c1 Mon Sep 17 00:00:00 2001 From: FangYin Cheng Date: Thu, 12 Oct 2023 01:18:01 +0800 Subject: [PATCH] feat(core): New command line for analyze and visualize trace spans --- docker/base/build_image.sh | 4 +- pilot/model/cluster/base.py | 2 + pilot/model/cluster/embedding/loader.py | 26 +- pilot/model/cluster/worker/default_worker.py | 29 +- pilot/model/cluster/worker/manager.py | 158 +++--- pilot/openapi/api_v1/api_v1.py | 17 +- pilot/scene/base_chat.py | 10 +- pilot/scripts/cli_scripts.py | 8 + pilot/server/component_configs.py | 4 - pilot/server/dbgpt_server.py | 38 +- pilot/server/llmserver.py | 2 +- pilot/utils/parameter_utils.py | 16 +- pilot/utils/tracer/__init__.py | 4 + pilot/utils/tracer/base.py | 33 +- pilot/utils/tracer/tracer_cli.py | 540 +++++++++++++++++++ pilot/utils/tracer/tracer_impl.py | 61 ++- pilot/utils/tracer/tracer_middleware.py | 6 +- 17 files changed, 830 insertions(+), 128 deletions(-) create mode 100644 pilot/utils/tracer/tracer_cli.py diff --git a/docker/base/build_image.sh b/docker/base/build_image.sh index 9ecd9db9d..32846936f 100755 --- a/docker/base/build_image.sh +++ b/docker/base/build_image.sh @@ -15,7 +15,7 @@ IMAGE_NAME_ARGS="" PIP_INDEX_URL="https://pypi.org/simple" # en or zh LANGUAGE="en" -BUILD_LOCAL_CODE="false" +BUILD_LOCAL_CODE="true" LOAD_EXAMPLES="true" BUILD_NETWORK="" DB_GPT_INSTALL_MODEL="default" @@ -26,7 +26,7 @@ usage () { echo " [-n|--image-name image name] Current image name, default: db-gpt" echo " [-i|--pip-index-url pip index url] Pip index url, default: https://pypi.org/simple" echo " [--language en or zh] You language, default: en" - echo " [--build-local-code true or false] Whether to use the local project code to package the image, default: false" + echo " [--build-local-code true or false] Whether to use the local project code to package the image, default: true" echo " [--load-examples true or false] Whether to load examples to default database default: true" 
echo " [--network network name] The network of docker build" echo " [--install-mode mode name] Installation mode name, default: default, If you completely use openai's service, you can set the mode name to 'openai'" diff --git a/pilot/model/cluster/base.py b/pilot/model/cluster/base.py index 7d97e6bd9..9d22161b1 100644 --- a/pilot/model/cluster/base.py +++ b/pilot/model/cluster/base.py @@ -18,11 +18,13 @@ class PromptRequest(BaseModel): max_new_tokens: int = None stop: str = None echo: bool = True + span_id: str = None class EmbeddingsRequest(BaseModel): model: str input: List[str] + span_id: str = None class WorkerApplyRequest(BaseModel): diff --git a/pilot/model/cluster/embedding/loader.py b/pilot/model/cluster/embedding/loader.py index 63f6c452d..caf4bda9a 100644 --- a/pilot/model/cluster/embedding/loader.py +++ b/pilot/model/cluster/embedding/loader.py @@ -3,6 +3,8 @@ from __future__ import annotations from typing import TYPE_CHECKING from pilot.model.parameter import BaseEmbeddingModelParameters +from pilot.utils.parameter_utils import _get_dict_from_obj +from pilot.utils.tracer import root_tracer, SpanType, SpanTypeRunName if TYPE_CHECKING: from langchain.embeddings.base import Embeddings @@ -15,13 +17,21 @@ class EmbeddingLoader: def load( self, model_name: str, param: BaseEmbeddingModelParameters ) -> "Embeddings": - # add more models - if model_name in ["proxy_openai", "proxy_azure"]: - from langchain.embeddings import OpenAIEmbeddings + metadata = { + "model_name": model_name, + "run_service": SpanTypeRunName.EMBEDDING_MODEL.value, + "params": _get_dict_from_obj(param), + } + with root_tracer.start_span( + "EmbeddingLoader.load", span_type=SpanType.RUN, metadata=metadata + ): + # add more models + if model_name in ["proxy_openai", "proxy_azure"]: + from langchain.embeddings import OpenAIEmbeddings - return OpenAIEmbeddings(**param.build_kwargs()) - else: - from langchain.embeddings import HuggingFaceEmbeddings + return 
OpenAIEmbeddings(**param.build_kwargs()) + else: + from langchain.embeddings import HuggingFaceEmbeddings - kwargs = param.build_kwargs(model_name=param.model_path) - return HuggingFaceEmbeddings(**kwargs) + kwargs = param.build_kwargs(model_name=param.model_path) + return HuggingFaceEmbeddings(**kwargs) diff --git a/pilot/model/cluster/worker/default_worker.py b/pilot/model/cluster/worker/default_worker.py index 5873ab7d0..378fee2ea 100644 --- a/pilot/model/cluster/worker/default_worker.py +++ b/pilot/model/cluster/worker/default_worker.py @@ -9,8 +9,8 @@ from pilot.model.loader import ModelLoader, _get_model_real_path from pilot.model.parameter import ModelParameters from pilot.model.cluster.worker_base import ModelWorker from pilot.utils.model_utils import _clear_model_cache -from pilot.utils.parameter_utils import EnvArgumentParser -from pilot.utils.tracer import root_tracer +from pilot.utils.parameter_utils import EnvArgumentParser, _get_dict_from_obj +from pilot.utils.tracer import root_tracer, SpanType, SpanTypeRunName logger = logging.getLogger(__name__) @@ -95,9 +95,20 @@ class DefaultModelWorker(ModelWorker): model_params = self.parse_parameters(command_args) self._model_params = model_params logger.info(f"Begin load model, model params: {model_params}") - self.model, self.tokenizer = self.ml.loader_with_params( - model_params, self.llm_adapter - ) + metadata = { + "model_name": self.model_name, + "model_path": self.model_path, + "model_type": self.llm_adapter.model_type(), + "llm_adapter": str(self.llm_adapter), + "run_service": SpanTypeRunName.MODEL_WORKER, + "params": _get_dict_from_obj(model_params), + } + with root_tracer.start_span( + "DefaultModelWorker.start", span_type=SpanType.RUN, metadata=metadata + ): + self.model, self.tokenizer = self.ml.loader_with_params( + model_params, self.llm_adapter + ) def stop(self) -> None: if not self.model: @@ -110,7 +121,9 @@ class DefaultModelWorker(ModelWorker): _clear_model_cache(self._model_params.device) 
def generate_stream(self, params: Dict) -> Iterator[ModelOutput]: - span = root_tracer.start_span("DefaultModelWorker.generate_stream") + span = root_tracer.start_span( + "DefaultModelWorker.generate_stream", params.get("span_id") + ) try: ( params, @@ -153,7 +166,9 @@ class DefaultModelWorker(ModelWorker): raise NotImplementedError async def async_generate_stream(self, params: Dict) -> Iterator[ModelOutput]: - span = root_tracer.start_span("DefaultModelWorker.async_generate_stream") + span = root_tracer.start_span( + "DefaultModelWorker.async_generate_stream", params.get("span_id") + ) try: ( params, diff --git a/pilot/model/cluster/worker/manager.py b/pilot/model/cluster/worker/manager.py index b7b9515c5..bd941781d 100644 --- a/pilot/model/cluster/worker/manager.py +++ b/pilot/model/cluster/worker/manager.py @@ -8,12 +8,13 @@ import sys import time from concurrent.futures import ThreadPoolExecutor from dataclasses import asdict -from typing import Awaitable, Callable, Dict, Iterator, List, Optional +from typing import Awaitable, Callable, Dict, Iterator, List from fastapi import APIRouter, FastAPI from fastapi.responses import StreamingResponse from pilot.component import SystemApp +from pilot.configs.model_config import LOGDIR from pilot.model.base import ( ModelInstance, ModelOutput, @@ -35,8 +36,10 @@ from pilot.utils.parameter_utils import ( EnvArgumentParser, ParameterDescription, _dict_to_command_args, + _get_dict_from_obj, ) from pilot.utils.utils import setup_logging +from pilot.utils.tracer import initialize_tracer, root_tracer, SpanType, SpanTypeRunName logger = logging.getLogger(__name__) @@ -293,60 +296,72 @@ class LocalWorkerManager(WorkerManager): self, params: Dict, async_wrapper=None, **kwargs ) -> Iterator[ModelOutput]: """Generate stream result, chat scene""" - try: - worker_run_data = await self._get_model(params) - except Exception as e: - yield ModelOutput( - text=f"**LLMServer Generate Error, Please CheckErrorInfo.**: {e}", - error_code=0, - 
) - return - async with worker_run_data.semaphore: - if worker_run_data.worker.support_async(): - async for outout in worker_run_data.worker.async_generate_stream( - params - ): - yield outout - else: - if not async_wrapper: - from starlette.concurrency import iterate_in_threadpool + with root_tracer.start_span( + "WorkerManager.generate_stream", params.get("span_id") + ) as span: + params["span_id"] = span.span_id + try: + worker_run_data = await self._get_model(params) + except Exception as e: + yield ModelOutput( + text=f"**LLMServer Generate Error, Please CheckErrorInfo.**: {e}", + error_code=0, + ) + return + async with worker_run_data.semaphore: + if worker_run_data.worker.support_async(): + async for outout in worker_run_data.worker.async_generate_stream( + params + ): + yield outout + else: + if not async_wrapper: + from starlette.concurrency import iterate_in_threadpool - async_wrapper = iterate_in_threadpool - async for output in async_wrapper( - worker_run_data.worker.generate_stream(params) - ): - yield output + async_wrapper = iterate_in_threadpool + async for output in async_wrapper( + worker_run_data.worker.generate_stream(params) + ): + yield output async def generate(self, params: Dict) -> ModelOutput: """Generate non stream result""" - try: - worker_run_data = await self._get_model(params) - except Exception as e: - return ModelOutput( - text=f"**LLMServer Generate Error, Please CheckErrorInfo.**: {e}", - error_code=0, - ) - async with worker_run_data.semaphore: - if worker_run_data.worker.support_async(): - return await worker_run_data.worker.async_generate(params) - else: - return await self.run_blocking_func( - worker_run_data.worker.generate, params + with root_tracer.start_span( + "WorkerManager.generate", params.get("span_id") + ) as span: + params["span_id"] = span.span_id + try: + worker_run_data = await self._get_model(params) + except Exception as e: + return ModelOutput( + text=f"**LLMServer Generate Error, Please CheckErrorInfo.**: 
{e}", + error_code=0, ) + async with worker_run_data.semaphore: + if worker_run_data.worker.support_async(): + return await worker_run_data.worker.async_generate(params) + else: + return await self.run_blocking_func( + worker_run_data.worker.generate, params + ) async def embeddings(self, params: Dict) -> List[List[float]]: """Embed input""" - try: - worker_run_data = await self._get_model(params, worker_type="text2vec") - except Exception as e: - raise e - async with worker_run_data.semaphore: - if worker_run_data.worker.support_async(): - return await worker_run_data.worker.async_embeddings(params) - else: - return await self.run_blocking_func( - worker_run_data.worker.embeddings, params - ) + with root_tracer.start_span( + "WorkerManager.embeddings", params.get("span_id") + ) as span: + params["span_id"] = span.span_id + try: + worker_run_data = await self._get_model(params, worker_type="text2vec") + except Exception as e: + raise e + async with worker_run_data.semaphore: + if worker_run_data.worker.support_async(): + return await worker_run_data.worker.async_embeddings(params) + else: + return await self.run_blocking_func( + worker_run_data.worker.embeddings, params + ) def sync_embeddings(self, params: Dict) -> List[List[float]]: worker_run_data = self._sync_get_model(params, worker_type="text2vec") @@ -608,6 +623,9 @@ async def generate_json_stream(params): @router.post("/worker/generate_stream") async def api_generate_stream(request: PromptRequest): params = request.dict(exclude_none=True) + span_id = root_tracer.get_current_span_id() + if "span_id" not in params and span_id: + params["span_id"] = span_id generator = generate_json_stream(params) return StreamingResponse(generator) @@ -615,12 +633,18 @@ async def api_generate_stream(request: PromptRequest): @router.post("/worker/generate") async def api_generate(request: PromptRequest): params = request.dict(exclude_none=True) + span_id = root_tracer.get_current_span_id() + if "span_id" not in params and 
span_id: + params["span_id"] = span_id return await worker_manager.generate(params) @router.post("/worker/embeddings") async def api_embeddings(request: EmbeddingsRequest): params = request.dict(exclude_none=True) + span_id = root_tracer.get_current_span_id() + if "span_id" not in params and span_id: + params["span_id"] = span_id return await worker_manager.embeddings(params) @@ -801,10 +825,18 @@ def _build_worker(worker_params: ModelWorkerParameters): def _start_local_worker( worker_manager: WorkerManagerAdapter, worker_params: ModelWorkerParameters ): - worker = _build_worker(worker_params) - if not worker_manager.worker_manager: - worker_manager.worker_manager = _create_local_model_manager(worker_params) - worker_manager.worker_manager.add_worker(worker, worker_params) + with root_tracer.start_span( + "WorkerManager._start_local_worker", + span_type=SpanType.RUN, + metadata={ + "run_service": SpanTypeRunName.WORKER_MANAGER, + "params": _get_dict_from_obj(worker_params), + }, + ): + worker = _build_worker(worker_params) + if not worker_manager.worker_manager: + worker_manager.worker_manager = _create_local_model_manager(worker_params) + worker_manager.worker_manager.add_worker(worker, worker_params) def _start_local_embedding_worker( @@ -928,17 +960,17 @@ def run_worker_manager( # Run worker manager independently embedded_mod = False app = _setup_fastapi(worker_params) - _start_local_worker(worker_manager, worker_params) - _start_local_embedding_worker( - worker_manager, embedding_model_name, embedding_model_path - ) - else: - _start_local_worker(worker_manager, worker_params) - _start_local_embedding_worker( - worker_manager, embedding_model_name, embedding_model_path - ) - loop = asyncio.get_event_loop() - loop.run_until_complete(worker_manager.start()) + + system_app = SystemApp(app) + initialize_tracer( + system_app, + os.path.join(LOGDIR, "dbgpt_model_worker_manager_tracer.jsonl"), + root_operation_name="DB-GPT-WorkerManager-Entry", + ) + 
_start_local_worker(worker_manager, worker_params) + _start_local_embedding_worker( + worker_manager, embedding_model_name, embedding_model_path + ) if include_router: app.include_router(router, prefix="/api") @@ -946,6 +978,8 @@ def run_worker_manager( if not embedded_mod: import uvicorn + loop = asyncio.get_event_loop() + loop.run_until_complete(worker_manager.start()) uvicorn.run( app, host=worker_params.host, port=worker_params.port, log_level="info" ) diff --git a/pilot/openapi/api_v1/api_v1.py b/pilot/openapi/api_v1/api_v1.py index dfcc91f27..aba2b627f 100644 --- a/pilot/openapi/api_v1/api_v1.py +++ b/pilot/openapi/api_v1/api_v1.py @@ -46,7 +46,7 @@ from pilot.summary.db_summary_client import DBSummaryClient from pilot.model.cluster import BaseModelController, WorkerManager, WorkerManagerFactory from pilot.model.base import FlatSupportedModel -from pilot.utils.tracer import root_tracer +from pilot.utils.tracer import root_tracer, SpanType router = APIRouter() CFG = Config() @@ -367,7 +367,9 @@ async def chat_completions(dialogue: ConversationVo = Body()): print( f"chat_completions:{dialogue.chat_mode},{dialogue.select_param},{dialogue.model_name}" ) - with root_tracer.start_span("chat_completions", metadata=dialogue.dict()) as _: + with root_tracer.start_span( + "get_chat_instance", span_type=SpanType.CHAT, metadata=dialogue.dict() + ): chat: BaseChat = get_chat_instance(dialogue) # background_tasks = BackgroundTasks() # background_tasks.add_task(release_model_semaphore) @@ -419,11 +421,10 @@ async def model_supports(worker_manager: WorkerManager = Depends(get_worker_mana async def no_stream_generator(chat): - span = root_tracer.start_span("no_stream_generator") - msg = await chat.nostream_call() - msg = msg.replace("\n", "\\n") - yield f"data: {msg}\n\n" - span.end() + with root_tracer.start_span("no_stream_generator"): + msg = await chat.nostream_call() + msg = msg.replace("\n", "\\n") + yield f"data: {msg}\n\n" async def stream_generator(chat, incremental: 
bool, model_name: str): @@ -467,9 +468,9 @@ async def stream_generator(chat, incremental: bool, model_name: str): yield f"data:{msg}\n\n" previous_response = msg await asyncio.sleep(0.02) - span.end() if incremental: yield "data: [DONE]\n\n" + span.end() chat.current_message.add_ai_message(msg) chat.current_message.add_view_message(msg) chat.memory.append(chat.current_message) diff --git a/pilot/scene/base_chat.py b/pilot/scene/base_chat.py index 0ee6309ae..30ec078fc 100644 --- a/pilot/scene/base_chat.py +++ b/pilot/scene/base_chat.py @@ -139,7 +139,9 @@ class BaseChat(ABC): def _get_span_metadata(self, payload: Dict) -> Dict: metadata = {k: v for k, v in payload.items()} del metadata["prompt"] - metadata["messages"] = list(map(lambda m: m.dict(), metadata["messages"])) + metadata["messages"] = list( + map(lambda m: m if isinstance(m, dict) else m.dict(), metadata["messages"]) + ) return metadata async def stream_call(self): @@ -152,6 +154,7 @@ class BaseChat(ABC): span = root_tracer.start_span( "BaseChat.stream_call", metadata=self._get_span_metadata(payload) ) + payload["span_id"] = span.span_id try: from pilot.model.cluster import WorkerManagerFactory @@ -178,6 +181,7 @@ class BaseChat(ABC): span = root_tracer.start_span( "BaseChat.nostream_call", metadata=self._get_span_metadata(payload) ) + payload["span_id"] = span.span_id try: from pilot.model.cluster import WorkerManagerFactory @@ -185,7 +189,7 @@ class BaseChat(ABC): ComponentType.WORKER_MANAGER_FACTORY, WorkerManagerFactory ).create() - with root_tracer.start_span("BaseChat.invoke_worker_manager.generate") as _: + with root_tracer.start_span("BaseChat.invoke_worker_manager.generate"): model_output = await worker_manager.generate(payload) ### output parse @@ -206,7 +210,7 @@ class BaseChat(ABC): "ai_response_text": ai_response_text, "prompt_define_response": prompt_define_response, } - with root_tracer.start_span("BaseChat.do_action", metadata=metadata) as _: + with 
root_tracer.start_span("BaseChat.do_action", metadata=metadata): ### run result = self.do_action(prompt_define_response) diff --git a/pilot/scripts/cli_scripts.py b/pilot/scripts/cli_scripts.py index a0a7f029e..a51c2b343 100644 --- a/pilot/scripts/cli_scripts.py +++ b/pilot/scripts/cli_scripts.py @@ -119,6 +119,14 @@ except ImportError as e: logging.warning(f"Integrating dbgpt knowledge command line tool failed: {e}") +try: + from pilot.utils.tracer.tracer_cli import trace_cli_group + + add_command_alias(trace_cli_group, name="trace", parent_group=cli) +except ImportError as e: + logging.warning(f"Integrating dbgpt trace command line tool failed: {e}") + + def main(): return cli() diff --git a/pilot/server/component_configs.py b/pilot/server/component_configs.py index c72b0e5fd..91fb9de6e 100644 --- a/pilot/server/component_configs.py +++ b/pilot/server/component_configs.py @@ -8,8 +8,6 @@ from pilot.component import ComponentType, SystemApp from pilot.utils.executor_utils import DefaultExecutorFactory from pilot.embedding_engine.embedding_factory import EmbeddingFactory from pilot.server.base import WebWerverParameters -from pilot.configs.model_config import LOGDIR -from pilot.utils.tracer import root_tracer, initialize_tracer if TYPE_CHECKING: from langchain.embeddings.base import Embeddings @@ -26,8 +24,6 @@ def initialize_components( ): from pilot.model.cluster.controller.controller import controller - initialize_tracer(system_app, os.path.join(LOGDIR, "dbgpt_webserver_tracer.jsonl")) - # Register global default executor factory first system_app.register(DefaultExecutorFactory) diff --git a/pilot/server/dbgpt_server.py b/pilot/server/dbgpt_server.py index 0a5b19933..e58c61756 100644 --- a/pilot/server/dbgpt_server.py +++ b/pilot/server/dbgpt_server.py @@ -7,7 +7,7 @@ ROOT_PATH = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__fi sys.path.append(ROOT_PATH) import signal from pilot.configs.config import Config -from pilot.configs.model_config 
import LLM_MODEL_CONFIG, EMBEDDING_MODEL_CONFIG +from pilot.configs.model_config import LLM_MODEL_CONFIG, EMBEDDING_MODEL_CONFIG, LOGDIR from pilot.component import SystemApp from pilot.server.base import ( @@ -38,6 +38,8 @@ from pilot.utils.utils import ( _get_logging_level, logging_str_to_uvicorn_level, ) +from pilot.utils.tracer import root_tracer, initialize_tracer, SpanType, SpanTypeRunName +from pilot.utils.parameter_utils import _get_dict_from_obj static_file_path = os.path.join(os.getcwd(), "server/static") @@ -98,17 +100,21 @@ def mount_static_files(app): app.add_exception_handler(RequestValidationError, validation_exception_handler) +def _get_webserver_params(args: List[str] = None): + from pilot.utils.parameter_utils import EnvArgumentParser + + parser: argparse.ArgumentParser = EnvArgumentParser.create_argparse_option( + WebWerverParameters + ) + return WebWerverParameters(**vars(parser.parse_args(args=args))) + + def initialize_app(param: WebWerverParameters = None, args: List[str] = None): """Initialize app If you use gunicorn as a process manager, initialize_app can be invoke in `on_starting` hook. 
""" if not param: - from pilot.utils.parameter_utils import EnvArgumentParser - - parser: argparse.ArgumentParser = EnvArgumentParser.create_argparse_option( - WebWerverParameters - ) - param = WebWerverParameters(**vars(parser.parse_args(args=args))) + param = _get_webserver_params(args) if not param.log_level: param.log_level = _get_logging_level() @@ -127,7 +133,7 @@ def initialize_app(param: WebWerverParameters = None, args: List[str] = None): model_start_listener = _create_model_start_listener(system_app) initialize_components(param, system_app, embedding_model_name, embedding_model_path) - model_path = LLM_MODEL_CONFIG[CFG.LLM_MODEL] + model_path = LLM_MODEL_CONFIG.get(CFG.LLM_MODEL) if not param.light: print("Model Unified Deployment Mode!") if not param.remote_embedding: @@ -174,8 +180,20 @@ def run_uvicorn(param: WebWerverParameters): def run_webserver(param: WebWerverParameters = None): - param = initialize_app(param) - run_uvicorn(param) + if not param: + param = _get_webserver_params() + initialize_tracer(system_app, os.path.join(LOGDIR, "dbgpt_webserver_tracer.jsonl")) + + with root_tracer.start_span( + "run_webserver", + span_type=SpanType.RUN, + metadata={ + "run_service": SpanTypeRunName.WEBSERVER, + "params": _get_dict_from_obj(param), + }, + ): + param = initialize_app(param) + run_uvicorn(param) if __name__ == "__main__": diff --git a/pilot/server/llmserver.py b/pilot/server/llmserver.py index d521a062d..1a2dd49ce 100644 --- a/pilot/server/llmserver.py +++ b/pilot/server/llmserver.py @@ -13,7 +13,7 @@ from pilot.model.cluster import run_worker_manager CFG = Config() -model_path = LLM_MODEL_CONFIG[CFG.LLM_MODEL] +model_path = LLM_MODEL_CONFIG.get(CFG.LLM_MODEL) if __name__ == "__main__": run_worker_manager( diff --git a/pilot/utils/parameter_utils.py b/pilot/utils/parameter_utils.py index 8acba8881..fbf9c5fb5 100644 --- a/pilot/utils/parameter_utils.py +++ b/pilot/utils/parameter_utils.py @@ -1,6 +1,6 @@ import argparse import os -from dataclasses 
import dataclass, fields, MISSING, asdict, field +from dataclasses import dataclass, fields, MISSING, asdict, field, is_dataclass from typing import Any, List, Optional, Type, Union, Callable, Dict from collections import OrderedDict @@ -590,6 +590,20 @@ def _extract_parameter_details( return descriptions +def _get_dict_from_obj(obj, default_value=None) -> Optional[Dict]: + if not obj: + return None + if is_dataclass(type(obj)): + params = {} + for field_info in fields(obj): + value = _get_simple_privacy_field_value(obj, field_info) + params[field_info.name] = value + return params + if isinstance(obj, dict): + return obj + return default_value + + class _SimpleArgParser: def __init__(self, *args): self.params = {arg.replace("_", "-"): None for arg in args} diff --git a/pilot/utils/tracer/__init__.py b/pilot/utils/tracer/__init__.py index 0163b443c..16509ff43 100644 --- a/pilot/utils/tracer/__init__.py +++ b/pilot/utils/tracer/__init__.py @@ -1,5 +1,7 @@ from pilot.utils.tracer.base import ( + SpanType, Span, + SpanTypeRunName, Tracer, SpanStorage, SpanStorageType, @@ -14,7 +16,9 @@ from pilot.utils.tracer.tracer_impl import ( ) __all__ = [ + "SpanType", "Span", + "SpanTypeRunName", "Tracer", "SpanStorage", "SpanStorageType", diff --git a/pilot/utils/tracer/base.py b/pilot/utils/tracer/base.py index 9fd750b0c..e227d6314 100644 --- a/pilot/utils/tracer/base.py +++ b/pilot/utils/tracer/base.py @@ -10,22 +10,41 @@ from datetime import datetime from pilot.component import BaseComponent, SystemApp, ComponentType +class SpanType(str, Enum): + BASE = "base" + RUN = "run" + CHAT = "chat" + + +class SpanTypeRunName(str, Enum): + WEBSERVER = "Webserver" + WORKER_MANAGER = "WorkerManager" + MODEL_WORKER = "ModelWorker" + EMBEDDING_MODEL = "EmbeddingModel" + + @staticmethod + def values(): + return [item.value for item in SpanTypeRunName] + + class Span: """Represents a unit of work that is being traced. This can be any operation like a function call or a database query. 
""" - span_type: str = "base" - def __init__( self, trace_id: str, span_id: str, + span_type: SpanType = None, parent_span_id: str = None, operation_name: str = None, metadata: Dict = None, end_caller: Callable[[Span], None] = None, ): + if not span_type: + span_type = SpanType.BASE + self.span_type = span_type # The unique identifier for the entire trace self.trace_id = trace_id # Unique identifier for this span within the trace @@ -65,7 +84,7 @@ class Span: def to_dict(self) -> Dict: return { - "span_type": self.span_type, + "span_type": self.span_type.value, "trace_id": self.trace_id, "span_id": self.span_id, "parent_span_id": self.parent_span_id, @@ -124,7 +143,11 @@ class Tracer(BaseComponent, ABC): @abstractmethod def start_span( - self, operation_name: str, parent_span_id: str = None, metadata: Dict = None + self, + operation_name: str, + parent_span_id: str = None, + span_type: SpanType = None, + metadata: Dict = None, ) -> Span: """Begin a new span for the given operation. If provided, the span will be a child of the span with the given parent_span_id. 
@@ -158,4 +181,4 @@ class Tracer(BaseComponent, ABC): @dataclass class TracerContext: - span_id: str + span_id: Optional[str] = None diff --git a/pilot/utils/tracer/tracer_cli.py b/pilot/utils/tracer/tracer_cli.py new file mode 100644 index 000000000..e99bba049 --- /dev/null +++ b/pilot/utils/tracer/tracer_cli.py @@ -0,0 +1,540 @@ +import os +import click +import logging +import glob +import json +from datetime import datetime +from typing import Iterable, Dict, Callable +from pilot.configs.model_config import LOGDIR +from pilot.utils.tracer import SpanType, SpanTypeRunName + +logger = logging.getLogger("dbgpt_cli") + + +_DEFAULT_FILE_PATTERN = os.path.join(LOGDIR, "dbgpt*.jsonl") + + +@click.group("trace") +def trace_cli_group(): + """Analyze and visualize trace spans.""" + pass + + +@trace_cli_group.command() +@click.option( + "--trace_id", + required=False, + type=str, + default=None, + show_default=True, + help="Specify the trace ID to list", +) +@click.option( + "--span_id", + required=False, + type=str, + default=None, + show_default=True, + help="Specify the Span ID to list.", +) +@click.option( + "--span_type", + required=False, + type=str, + default=None, + show_default=True, + help="Specify the Span Type to list.", +) +@click.option( + "--parent_span_id", + required=False, + type=str, + default=None, + show_default=True, + help="Specify the Parent Span ID to list.", +) +@click.option( + "--search", + required=False, + type=str, + default=None, + show_default=True, + help="Search trace_id, span_id, parent_span_id, operation_name or content in metadata.", +) +@click.option( + "-l", + "--limit", + type=int, + default=20, + help="Limit the number of recent span displayed.", +) +@click.option( + "--start_time", + type=str, + help='Filter by start time. Format: "YYYY-MM-DD HH:MM:SS.mmm"', +) +@click.option( + "--end_time", type=str, help='Filter by end time. 
Format: "YYYY-MM-DD HH:MM:SS.mmm"' +) +@click.option( + "--desc", + required=False, + type=bool, + default=False, + is_flag=True, + help="Whether to use reverse sorting. By default, sorting is based on start time.", +) +@click.option( + "--output", + required=False, + type=click.Choice(["text", "html", "csv", "latex", "json"]), + default="text", + help="The output format", +) +@click.argument("files", nargs=-1, type=click.Path(exists=True, readable=True)) +def list( + trace_id: str, + span_id: str, + span_type: str, + parent_span_id: str, + search: str, + limit: int, + start_time: str, + end_time: str, + desc: bool, + output: str, + files=None, +): + """List your trace spans""" + from prettytable import PrettyTable + + # If no files are explicitly specified, use the default pattern to get them + spans = read_spans_from_files(files) + + if trace_id: + spans = filter(lambda s: s["trace_id"] == trace_id, spans) + if span_id: + spans = filter(lambda s: s["span_id"] == span_id, spans) + if span_type: + spans = filter(lambda s: s["span_type"] == span_type, spans) + if parent_span_id: + spans = filter(lambda s: s["parent_span_id"] == parent_span_id, spans) + # Filter spans based on the start and end times + if start_time: + start_dt = _parse_datetime(start_time) + spans = filter( + lambda span: _parse_datetime(span["start_time"]) >= start_dt, spans + ) + + if end_time: + end_dt = _parse_datetime(end_time) + spans = filter( + lambda span: _parse_datetime(span["start_time"]) <= end_dt, spans + ) + + if search: + spans = filter(_new_search_span_func(search), spans) + + # Sort spans based on the start time + spans = sorted( + spans, key=lambda span: _parse_datetime(span["start_time"]), reverse=desc + )[:limit] + + table = PrettyTable( + ["Trace ID", "Span ID", "Operation Name", "Conversation UID"], + ) + + for sp in spans: + conv_uid = None + if "metadata" in sp and sp: + metadata = sp["metadata"] + if isinstance(metadata, dict): + conv_uid = metadata.get("conv_uid") + 
table.add_row( + [ + sp.get("trace_id"), + sp.get("span_id"), + # sp.get("parent_span_id"), + sp.get("operation_name"), + conv_uid, + ] + ) + out_kwargs = {"ensure_ascii": False} if output == "json" else {} + print(table.get_formatted_string(out_format=output, **out_kwargs)) + + +@trace_cli_group.command() +@click.option( + "--trace_id", + required=True, + type=str, + help="Specify the trace ID to list", +) +@click.argument("files", nargs=-1, type=click.Path(exists=True, readable=True)) +def tree(trace_id: str, files): + """Display trace links as a tree""" + hierarchy = _view_trace_hierarchy(trace_id, files) + _print_trace_hierarchy(hierarchy) + + +@trace_cli_group.command() +@click.option( + "--trace_id", + required=False, + type=str, + default=None, + help="Specify the trace ID to analyze. If None, show latest conversation details", +) +@click.option( + "--tree", + required=False, + type=bool, + default=False, + is_flag=True, + help="Display trace spans as a tree", +) +@click.option( + "--hide_run_params", + required=False, + type=bool, + default=False, + is_flag=True, + help="Hide run params", +) +@click.option( + "--output", + required=False, + type=click.Choice(["text", "html", "csv", "latex", "json"]), + default="text", + help="The output format", +) +@click.argument("files", nargs=-1, type=click.Path(exists=False, readable=True)) +def chat(trace_id: str, tree: bool, hide_run_params: bool, output: str, files): + """Show conversation details""" + from prettytable import PrettyTable + + spans = read_spans_from_files(files) + + # Sort by start time + spans = sorted( + spans, key=lambda span: _parse_datetime(span["start_time"]), reverse=True + ) + spans = [sp for sp in spans] + service_spans = {} + service_names = set(SpanTypeRunName.values()) + found_trace_id = None + for sp in spans: + span_type = sp["span_type"] + metadata = sp.get("metadata") + if span_type == SpanType.RUN: + service_name = metadata["run_service"] + service_spans[service_name] = sp.copy() + 
if set(service_spans.keys()) == service_names and found_trace_id: + break + elif span_type == SpanType.CHAT and not found_trace_id: + if not trace_id: + found_trace_id = sp["trace_id"] + if trace_id and trace_id == sp["trace_id"]: + found_trace_id = trace_id + + service_tables = {} + out_kwargs = {"ensure_ascii": False} if output == "json" else {} + for service_name, sp in service_spans.items(): + metadata = sp["metadata"] + table = PrettyTable(["Config Key", "Config Value"], title=service_name) + for k, v in metadata["params"].items(): + table.add_row([k, v]) + service_tables[service_name] = table + + if not hide_run_params: + merged_table1 = merge_tables_horizontally( + [ + service_tables.get(SpanTypeRunName.WEBSERVER.value), + service_tables.get(SpanTypeRunName.EMBEDDING_MODEL.value), + ] + ) + merged_table2 = merge_tables_horizontally( + [ + service_tables.get(SpanTypeRunName.MODEL_WORKER), + service_tables.get(SpanTypeRunName.WORKER_MANAGER), + ] + ) + if output == "text": + print(merged_table1) + print(merged_table2) + else: + for service_name, table in service_tables.items(): + print(table.get_formatted_string(out_format=output, **out_kwargs)) + + if not found_trace_id: + print(f"Can't found conversation with trace_id: {trace_id}") + return + trace_id = found_trace_id + + trace_spans = [span for span in spans if span["trace_id"] == trace_id] + trace_spans = [s for s in reversed(trace_spans)] + hierarchy = _build_trace_hierarchy(trace_spans) + if tree: + print("\nInvoke Trace Tree:\n") + _print_trace_hierarchy(hierarchy) + + trace_spans = _get_ordered_trace_from(hierarchy) + table = PrettyTable(["Key", "Value Value"], title="Chat Trace Details") + split_long_text = output == "text" + + for sp in trace_spans: + op = sp["operation_name"] + metadata = sp.get("metadata") + if op == "get_chat_instance" and not sp["end_time"]: + table.add_row(["trace_id", trace_id]) + table.add_row(["span_id", sp["span_id"]]) + table.add_row(["conv_uid", metadata.get("conv_uid")]) 
+            table.add_row(["user_input", metadata.get("user_input")])
+            table.add_row(["chat_mode", metadata.get("chat_mode")])
+            table.add_row(["select_param", metadata.get("select_param")])
+            table.add_row(["model_name", metadata.get("model_name")])
+        if op in ["BaseChat.stream_call", "BaseChat.nostream_call"]:
+            if not sp["end_time"]:
+                table.add_row(["temperature", metadata.get("temperature")])
+                table.add_row(["max_new_tokens", metadata.get("max_new_tokens")])
+                table.add_row(["echo", metadata.get("echo")])
+            elif metadata and "error" in metadata:
+                table.add_row(["BaseChat Error", metadata.get("error")])
+        if op == "BaseChat.nostream_call" and not sp["end_time"]:
+            if "model_output" in metadata:
+                table.add_row(
+                    [
+                        "BaseChat model_output",
+                        split_string_by_terminal_width(
+                            metadata.get("model_output").get("text"),
+                            split=split_long_text,
+                        ),
+                    ]
+                )
+            if "ai_response_text" in metadata:
+                table.add_row(
+                    [
+                        "BaseChat ai_response_text",
+                        split_string_by_terminal_width(
+                            metadata.get("ai_response_text"), split=split_long_text
+                        ),
+                    ]
+                )
+            if "prompt_define_response" in metadata:
+                table.add_row(
+                    [
+                        "BaseChat prompt_define_response",
+                        split_string_by_terminal_width(
+                            metadata.get("prompt_define_response"),
+                            split=split_long_text,
+                        ),
+                    ]
+                )
+        if op == "DefaultModelWorker_call.generate_stream_func":
+            if not sp["end_time"]:
+                table.add_row(["llm_adapter", metadata.get("llm_adapter")])
+                table.add_row(
+                    [
+                        "User prompt",
+                        split_string_by_terminal_width(
+                            metadata.get("prompt"), split=split_long_text
+                        ),
+                    ]
+                )
+            else:
+                table.add_row(
+                    [
+                        "Model output",
+                        split_string_by_terminal_width(metadata.get("output"), split=split_long_text),
+                    ]
+                )
+        if (
+            op
+            in [
+                "DefaultModelWorker.async_generate_stream",
+                "DefaultModelWorker.generate_stream",
+            ]
+            and metadata
+            and "error" in metadata
+        ):
+            table.add_row(["Model Error", metadata.get("error")])
+    print(table.get_formatted_string(out_format=output, **out_kwargs))
+
+
+def read_spans_from_files(files=None) -> Iterable[Dict]:
+    
""" + Reads spans from multiple files based on the provided file paths. + """ + if not files: + files = [_DEFAULT_FILE_PATTERN] + + for filepath in files: + for filename in glob.glob(filepath): + with open(filename, "r") as file: + for line in file: + yield json.loads(line) + + +def _new_search_span_func(search: str): + def func(span: Dict) -> bool: + items = [span["trace_id"], span["span_id"], span["parent_span_id"]] + if "operation_name" in span: + items.append(span["operation_name"]) + if "metadata" in span: + metadata = span["metadata"] + if isinstance(metadata, dict): + for k, v in metadata.items(): + items.append(k) + items.append(v) + return any(search in str(item) for item in items if item) + + return func + + +def _parse_datetime(dt_str): + """Parse a datetime string to a datetime object.""" + return datetime.strptime(dt_str, "%Y-%m-%d %H:%M:%S.%f") + + +def _build_trace_hierarchy(spans, parent_span_id=None, indent=0): + # Current spans + current_level_spans = [ + span + for span in spans + if span["parent_span_id"] == parent_span_id and span["end_time"] is None + ] + + hierarchy = [] + + for start_span in current_level_spans: + # Find end span + end_span = next( + ( + span + for span in spans + if span["span_id"] == start_span["span_id"] + and span["end_time"] is not None + ), + None, + ) + entry = { + "operation_name": start_span["operation_name"], + "parent_span_id": start_span["parent_span_id"], + "span_id": start_span["span_id"], + "start_time": start_span["start_time"], + "end_time": start_span["end_time"], + "metadata": start_span["metadata"], + "children": _build_trace_hierarchy( + spans, start_span["span_id"], indent + 1 + ), + } + hierarchy.append(entry) + + # Append end span + if end_span: + entry_end = { + "operation_name": end_span["operation_name"], + "parent_span_id": end_span["parent_span_id"], + "span_id": end_span["span_id"], + "start_time": end_span["start_time"], + "end_time": end_span["end_time"], + "metadata": end_span["metadata"], + 
"children": [], + } + hierarchy.append(entry_end) + + return hierarchy + + +def _view_trace_hierarchy(trace_id, files=None): + """Find and display the calls of the entire link based on the given trace_id""" + spans = read_spans_from_files(files) + trace_spans = [span for span in spans if span["trace_id"] == trace_id] + hierarchy = _build_trace_hierarchy(trace_spans) + return hierarchy + + +def _print_trace_hierarchy(hierarchy, indent=0): + """Print link hierarchy""" + for entry in hierarchy: + print( + " " * indent + + f"Operation: {entry['operation_name']} (Start: {entry['start_time']}, End: {entry['end_time']})" + ) + _print_trace_hierarchy(entry["children"], indent + 1) + + +def _get_ordered_trace_from(hierarchy): + traces = [] + + def func(items): + for item in items: + traces.append(item) + func(item["children"]) + + func(hierarchy) + return traces + + +def _print(service_spans: Dict): + for names in [ + [SpanTypeRunName.WEBSERVER.name, SpanTypeRunName.EMBEDDING_MODEL], + [SpanTypeRunName.WORKER_MANAGER.name, SpanTypeRunName.MODEL_WORKER], + ]: + pass + + +def merge_tables_horizontally(tables): + from prettytable import PrettyTable + + if not tables: + return None + + tables = [t for t in tables if t] + if not tables: + return None + + max_rows = max(len(table._rows) for table in tables) + + merged_table = PrettyTable() + + new_field_names = [] + for table in tables: + new_field_names.extend( + [ + f"{name} ({table.title})" if table.title else f"{name}" + for name in table.field_names + ] + ) + + merged_table.field_names = new_field_names + + for i in range(max_rows): + merged_row = [] + for table in tables: + if i < len(table._rows): + merged_row.extend(table._rows[i]) + else: + # Fill empty cells for shorter tables + merged_row.extend([""] * len(table.field_names)) + merged_table.add_row(merged_row) + + return merged_table + + +def split_string_by_terminal_width(s, split=True, max_len=None, sp="\n"): + """ + Split a string into substrings based on the 
current terminal width. + + Parameters: + - s: the input string + """ + if not split: + return s + if not max_len: + try: + max_len = int(os.get_terminal_size().columns * 0.8) + except OSError: + # Default to 80 columns if the terminal size can't be determined + max_len = 100 + return sp.join([s[i : i + max_len] for i in range(0, len(s), max_len)]) diff --git a/pilot/utils/tracer/tracer_impl.py b/pilot/utils/tracer/tracer_impl.py index 4f026e2b5..bda25ab4d 100644 --- a/pilot/utils/tracer/tracer_impl.py +++ b/pilot/utils/tracer/tracer_impl.py @@ -4,6 +4,7 @@ from functools import wraps from pilot.component import SystemApp, ComponentType from pilot.utils.tracer.base import ( + SpanType, Span, Tracer, SpanStorage, @@ -32,14 +33,23 @@ class DefaultTracer(Tracer): self._get_current_storage().append_span(span) def start_span( - self, operation_name: str, parent_span_id: str = None, metadata: Dict = None + self, + operation_name: str, + parent_span_id: str = None, + span_type: SpanType = None, + metadata: Dict = None, ) -> Span: trace_id = ( self._new_uuid() if parent_span_id is None else parent_span_id.split(":")[0] ) span_id = f"{trace_id}:{self._new_uuid()}" span = Span( - trace_id, span_id, parent_span_id, operation_name, metadata=metadata + trace_id, + span_id, + span_type, + parent_span_id, + operation_name, + metadata=metadata, ) if self._span_storage_type in [ @@ -81,11 +91,13 @@ class DefaultTracer(Tracer): class TracerManager: + """The manager of current tracer""" + def __init__(self) -> None: self._system_app: Optional[SystemApp] = None self._trace_context_var: ContextVar[TracerContext] = ContextVar( "trace_context", - default=TracerContext(span_id="default_trace_id:default_span_id"), + default=TracerContext(), ) def initialize( @@ -101,18 +113,23 @@ class TracerManager: return self._system_app.get_component(ComponentType.TRACER, Tracer, None) def start_span( - self, operation_name: str, parent_span_id: str = None, metadata: Dict = None + self, + 
operation_name: str, + parent_span_id: str = None, + span_type: SpanType = None, + metadata: Dict = None, ) -> Span: + """Start a new span with operation_name + This method must not throw an exception under any case and try not to block as much as possible + """ tracer = self._get_tracer() if not tracer: return Span("empty_span", "empty_span") if not parent_span_id: - current_span = self.get_current_span() - if current_span: - parent_span_id = current_span.span_id - else: - parent_span_id = self._trace_context_var.get().span_id - return tracer.start_span(operation_name, parent_span_id, metadata) + parent_span_id = self.get_current_span_id() + return tracer.start_span( + operation_name, parent_span_id, span_type=span_type, metadata=metadata + ) def end_span(self, span: Span, **kwargs): tracer = self._get_tracer() @@ -126,15 +143,22 @@ class TracerManager: return None return tracer.get_current_span() + def get_current_span_id(self) -> Optional[str]: + current_span = self.get_current_span() + if current_span: + return current_span.span_id + ctx = self._trace_context_var.get() + return ctx.span_id if ctx else None + root_tracer: TracerManager = TracerManager() -def trace(operation_name: str): +def trace(operation_name: str, **trace_kwargs): def decorator(func): @wraps(func) async def wrapper(*args, **kwargs): - with root_tracer.start_span(operation_name) as span: + with root_tracer.start_span(operation_name, **trace_kwargs): return await func(*args, **kwargs) return wrapper @@ -142,14 +166,18 @@ def trace(operation_name: str): return decorator -def initialize_tracer(system_app: SystemApp, tracer_filename: str): +def initialize_tracer( + system_app: SystemApp, + tracer_filename: str, + root_operation_name: str = "DB-GPT-Web-Entry", +): if not system_app: return from pilot.utils.tracer.span_storage import FileSpanStorage trace_context_var = ContextVar( "trace_context", - default=TracerContext(span_id="default_trace_id:default_span_id"), + default=TracerContext(), ) 
tracer = DefaultTracer(system_app) @@ -160,5 +188,8 @@ def initialize_tracer(system_app: SystemApp, tracer_filename: str): from pilot.utils.tracer.tracer_middleware import TraceIDMiddleware system_app.app.add_middleware( - TraceIDMiddleware, trace_context_var=trace_context_var, tracer=tracer + TraceIDMiddleware, + trace_context_var=trace_context_var, + tracer=tracer, + root_operation_name=root_operation_name, ) diff --git a/pilot/utils/tracer/tracer_middleware.py b/pilot/utils/tracer/tracer_middleware.py index a7936f4cb..41f1b64dc 100644 --- a/pilot/utils/tracer/tracer_middleware.py +++ b/pilot/utils/tracer/tracer_middleware.py @@ -16,12 +16,14 @@ class TraceIDMiddleware(BaseHTTPMiddleware): app: ASGIApp, trace_context_var: ContextVar[TracerContext], tracer: Tracer, + root_operation_name: str = "DB-GPT-Web-Entry", include_prefix: str = "/api", exclude_paths=_DEFAULT_EXCLUDE_PATHS, ): super().__init__(app) self.trace_context_var = trace_context_var self.tracer = tracer + self.root_operation_name = root_operation_name self.include_prefix = include_prefix self.exclude_paths = exclude_paths @@ -37,7 +39,7 @@ class TraceIDMiddleware(BaseHTTPMiddleware): # self.trace_context_var.set(TracerContext(span_id=span_id)) with self.tracer.start_span( - "DB-GPT-Web-Entry", span_id, metadata={"path": request.url.path} - ) as _: + self.root_operation_name, span_id, metadata={"path": request.url.path} + ): response = await call_next(request) return response