mirror of
https://github.com/csunny/DB-GPT.git
synced 2025-07-29 23:01:38 +00:00
feat(core): New command line for analyze and visualize trace spans
This commit is contained in:
parent
1e919aeef3
commit
2a46909eac
@ -15,7 +15,7 @@ IMAGE_NAME_ARGS=""
|
|||||||
PIP_INDEX_URL="https://pypi.org/simple"
|
PIP_INDEX_URL="https://pypi.org/simple"
|
||||||
# en or zh
|
# en or zh
|
||||||
LANGUAGE="en"
|
LANGUAGE="en"
|
||||||
BUILD_LOCAL_CODE="false"
|
BUILD_LOCAL_CODE="true"
|
||||||
LOAD_EXAMPLES="true"
|
LOAD_EXAMPLES="true"
|
||||||
BUILD_NETWORK=""
|
BUILD_NETWORK=""
|
||||||
DB_GPT_INSTALL_MODEL="default"
|
DB_GPT_INSTALL_MODEL="default"
|
||||||
@ -26,7 +26,7 @@ usage () {
|
|||||||
echo " [-n|--image-name image name] Current image name, default: db-gpt"
|
echo " [-n|--image-name image name] Current image name, default: db-gpt"
|
||||||
echo " [-i|--pip-index-url pip index url] Pip index url, default: https://pypi.org/simple"
|
echo " [-i|--pip-index-url pip index url] Pip index url, default: https://pypi.org/simple"
|
||||||
echo " [--language en or zh] You language, default: en"
|
echo " [--language en or zh] You language, default: en"
|
||||||
echo " [--build-local-code true or false] Whether to use the local project code to package the image, default: false"
|
echo " [--build-local-code true or false] Whether to use the local project code to package the image, default: true"
|
||||||
echo " [--load-examples true or false] Whether to load examples to default database default: true"
|
echo " [--load-examples true or false] Whether to load examples to default database default: true"
|
||||||
echo " [--network network name] The network of docker build"
|
echo " [--network network name] The network of docker build"
|
||||||
echo " [--install-mode mode name] Installation mode name, default: default, If you completely use openai's service, you can set the mode name to 'openai'"
|
echo " [--install-mode mode name] Installation mode name, default: default, If you completely use openai's service, you can set the mode name to 'openai'"
|
||||||
|
@ -18,11 +18,13 @@ class PromptRequest(BaseModel):
|
|||||||
max_new_tokens: int = None
|
max_new_tokens: int = None
|
||||||
stop: str = None
|
stop: str = None
|
||||||
echo: bool = True
|
echo: bool = True
|
||||||
|
span_id: str = None
|
||||||
|
|
||||||
|
|
||||||
class EmbeddingsRequest(BaseModel):
|
class EmbeddingsRequest(BaseModel):
|
||||||
model: str
|
model: str
|
||||||
input: List[str]
|
input: List[str]
|
||||||
|
span_id: str = None
|
||||||
|
|
||||||
|
|
||||||
class WorkerApplyRequest(BaseModel):
|
class WorkerApplyRequest(BaseModel):
|
||||||
|
@ -3,6 +3,8 @@ from __future__ import annotations
|
|||||||
from typing import TYPE_CHECKING
|
from typing import TYPE_CHECKING
|
||||||
|
|
||||||
from pilot.model.parameter import BaseEmbeddingModelParameters
|
from pilot.model.parameter import BaseEmbeddingModelParameters
|
||||||
|
from pilot.utils.parameter_utils import _get_dict_from_obj
|
||||||
|
from pilot.utils.tracer import root_tracer, SpanType, SpanTypeRunName
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from langchain.embeddings.base import Embeddings
|
from langchain.embeddings.base import Embeddings
|
||||||
@ -15,13 +17,21 @@ class EmbeddingLoader:
|
|||||||
def load(
|
def load(
|
||||||
self, model_name: str, param: BaseEmbeddingModelParameters
|
self, model_name: str, param: BaseEmbeddingModelParameters
|
||||||
) -> "Embeddings":
|
) -> "Embeddings":
|
||||||
# add more models
|
metadata = {
|
||||||
if model_name in ["proxy_openai", "proxy_azure"]:
|
"model_name": model_name,
|
||||||
from langchain.embeddings import OpenAIEmbeddings
|
"run_service": SpanTypeRunName.EMBEDDING_MODEL.value,
|
||||||
|
"params": _get_dict_from_obj(param),
|
||||||
|
}
|
||||||
|
with root_tracer.start_span(
|
||||||
|
"EmbeddingLoader.load", span_type=SpanType.RUN, metadata=metadata
|
||||||
|
):
|
||||||
|
# add more models
|
||||||
|
if model_name in ["proxy_openai", "proxy_azure"]:
|
||||||
|
from langchain.embeddings import OpenAIEmbeddings
|
||||||
|
|
||||||
return OpenAIEmbeddings(**param.build_kwargs())
|
return OpenAIEmbeddings(**param.build_kwargs())
|
||||||
else:
|
else:
|
||||||
from langchain.embeddings import HuggingFaceEmbeddings
|
from langchain.embeddings import HuggingFaceEmbeddings
|
||||||
|
|
||||||
kwargs = param.build_kwargs(model_name=param.model_path)
|
kwargs = param.build_kwargs(model_name=param.model_path)
|
||||||
return HuggingFaceEmbeddings(**kwargs)
|
return HuggingFaceEmbeddings(**kwargs)
|
||||||
|
@ -9,8 +9,8 @@ from pilot.model.loader import ModelLoader, _get_model_real_path
|
|||||||
from pilot.model.parameter import ModelParameters
|
from pilot.model.parameter import ModelParameters
|
||||||
from pilot.model.cluster.worker_base import ModelWorker
|
from pilot.model.cluster.worker_base import ModelWorker
|
||||||
from pilot.utils.model_utils import _clear_model_cache
|
from pilot.utils.model_utils import _clear_model_cache
|
||||||
from pilot.utils.parameter_utils import EnvArgumentParser
|
from pilot.utils.parameter_utils import EnvArgumentParser, _get_dict_from_obj
|
||||||
from pilot.utils.tracer import root_tracer
|
from pilot.utils.tracer import root_tracer, SpanType, SpanTypeRunName
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
@ -95,9 +95,20 @@ class DefaultModelWorker(ModelWorker):
|
|||||||
model_params = self.parse_parameters(command_args)
|
model_params = self.parse_parameters(command_args)
|
||||||
self._model_params = model_params
|
self._model_params = model_params
|
||||||
logger.info(f"Begin load model, model params: {model_params}")
|
logger.info(f"Begin load model, model params: {model_params}")
|
||||||
self.model, self.tokenizer = self.ml.loader_with_params(
|
metadata = {
|
||||||
model_params, self.llm_adapter
|
"model_name": self.model_name,
|
||||||
)
|
"model_path": self.model_path,
|
||||||
|
"model_type": self.llm_adapter.model_type(),
|
||||||
|
"llm_adapter": str(self.llm_adapter),
|
||||||
|
"run_service": SpanTypeRunName.MODEL_WORKER,
|
||||||
|
"params": _get_dict_from_obj(model_params),
|
||||||
|
}
|
||||||
|
with root_tracer.start_span(
|
||||||
|
"DefaultModelWorker.start", span_type=SpanType.RUN, metadata=metadata
|
||||||
|
):
|
||||||
|
self.model, self.tokenizer = self.ml.loader_with_params(
|
||||||
|
model_params, self.llm_adapter
|
||||||
|
)
|
||||||
|
|
||||||
def stop(self) -> None:
|
def stop(self) -> None:
|
||||||
if not self.model:
|
if not self.model:
|
||||||
@ -110,7 +121,9 @@ class DefaultModelWorker(ModelWorker):
|
|||||||
_clear_model_cache(self._model_params.device)
|
_clear_model_cache(self._model_params.device)
|
||||||
|
|
||||||
def generate_stream(self, params: Dict) -> Iterator[ModelOutput]:
|
def generate_stream(self, params: Dict) -> Iterator[ModelOutput]:
|
||||||
span = root_tracer.start_span("DefaultModelWorker.generate_stream")
|
span = root_tracer.start_span(
|
||||||
|
"DefaultModelWorker.generate_stream", params.get("span_id")
|
||||||
|
)
|
||||||
try:
|
try:
|
||||||
(
|
(
|
||||||
params,
|
params,
|
||||||
@ -153,7 +166,9 @@ class DefaultModelWorker(ModelWorker):
|
|||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
async def async_generate_stream(self, params: Dict) -> Iterator[ModelOutput]:
|
async def async_generate_stream(self, params: Dict) -> Iterator[ModelOutput]:
|
||||||
span = root_tracer.start_span("DefaultModelWorker.async_generate_stream")
|
span = root_tracer.start_span(
|
||||||
|
"DefaultModelWorker.async_generate_stream", params.get("span_id")
|
||||||
|
)
|
||||||
try:
|
try:
|
||||||
(
|
(
|
||||||
params,
|
params,
|
||||||
|
@ -8,12 +8,13 @@ import sys
|
|||||||
import time
|
import time
|
||||||
from concurrent.futures import ThreadPoolExecutor
|
from concurrent.futures import ThreadPoolExecutor
|
||||||
from dataclasses import asdict
|
from dataclasses import asdict
|
||||||
from typing import Awaitable, Callable, Dict, Iterator, List, Optional
|
from typing import Awaitable, Callable, Dict, Iterator, List
|
||||||
|
|
||||||
from fastapi import APIRouter, FastAPI
|
from fastapi import APIRouter, FastAPI
|
||||||
from fastapi.responses import StreamingResponse
|
from fastapi.responses import StreamingResponse
|
||||||
|
|
||||||
from pilot.component import SystemApp
|
from pilot.component import SystemApp
|
||||||
|
from pilot.configs.model_config import LOGDIR
|
||||||
from pilot.model.base import (
|
from pilot.model.base import (
|
||||||
ModelInstance,
|
ModelInstance,
|
||||||
ModelOutput,
|
ModelOutput,
|
||||||
@ -35,8 +36,10 @@ from pilot.utils.parameter_utils import (
|
|||||||
EnvArgumentParser,
|
EnvArgumentParser,
|
||||||
ParameterDescription,
|
ParameterDescription,
|
||||||
_dict_to_command_args,
|
_dict_to_command_args,
|
||||||
|
_get_dict_from_obj,
|
||||||
)
|
)
|
||||||
from pilot.utils.utils import setup_logging
|
from pilot.utils.utils import setup_logging
|
||||||
|
from pilot.utils.tracer import initialize_tracer, root_tracer, SpanType, SpanTypeRunName
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
@ -293,60 +296,72 @@ class LocalWorkerManager(WorkerManager):
|
|||||||
self, params: Dict, async_wrapper=None, **kwargs
|
self, params: Dict, async_wrapper=None, **kwargs
|
||||||
) -> Iterator[ModelOutput]:
|
) -> Iterator[ModelOutput]:
|
||||||
"""Generate stream result, chat scene"""
|
"""Generate stream result, chat scene"""
|
||||||
try:
|
with root_tracer.start_span(
|
||||||
worker_run_data = await self._get_model(params)
|
"WorkerManager.generate_stream", params.get("span_id")
|
||||||
except Exception as e:
|
) as span:
|
||||||
yield ModelOutput(
|
params["span_id"] = span.span_id
|
||||||
text=f"**LLMServer Generate Error, Please CheckErrorInfo.**: {e}",
|
try:
|
||||||
error_code=0,
|
worker_run_data = await self._get_model(params)
|
||||||
)
|
except Exception as e:
|
||||||
return
|
yield ModelOutput(
|
||||||
async with worker_run_data.semaphore:
|
text=f"**LLMServer Generate Error, Please CheckErrorInfo.**: {e}",
|
||||||
if worker_run_data.worker.support_async():
|
error_code=0,
|
||||||
async for outout in worker_run_data.worker.async_generate_stream(
|
)
|
||||||
params
|
return
|
||||||
):
|
async with worker_run_data.semaphore:
|
||||||
yield outout
|
if worker_run_data.worker.support_async():
|
||||||
else:
|
async for outout in worker_run_data.worker.async_generate_stream(
|
||||||
if not async_wrapper:
|
params
|
||||||
from starlette.concurrency import iterate_in_threadpool
|
):
|
||||||
|
yield outout
|
||||||
|
else:
|
||||||
|
if not async_wrapper:
|
||||||
|
from starlette.concurrency import iterate_in_threadpool
|
||||||
|
|
||||||
async_wrapper = iterate_in_threadpool
|
async_wrapper = iterate_in_threadpool
|
||||||
async for output in async_wrapper(
|
async for output in async_wrapper(
|
||||||
worker_run_data.worker.generate_stream(params)
|
worker_run_data.worker.generate_stream(params)
|
||||||
):
|
):
|
||||||
yield output
|
yield output
|
||||||
|
|
||||||
async def generate(self, params: Dict) -> ModelOutput:
|
async def generate(self, params: Dict) -> ModelOutput:
|
||||||
"""Generate non stream result"""
|
"""Generate non stream result"""
|
||||||
try:
|
with root_tracer.start_span(
|
||||||
worker_run_data = await self._get_model(params)
|
"WorkerManager.generate", params.get("span_id")
|
||||||
except Exception as e:
|
) as span:
|
||||||
return ModelOutput(
|
params["span_id"] = span.span_id
|
||||||
text=f"**LLMServer Generate Error, Please CheckErrorInfo.**: {e}",
|
try:
|
||||||
error_code=0,
|
worker_run_data = await self._get_model(params)
|
||||||
)
|
except Exception as e:
|
||||||
async with worker_run_data.semaphore:
|
return ModelOutput(
|
||||||
if worker_run_data.worker.support_async():
|
text=f"**LLMServer Generate Error, Please CheckErrorInfo.**: {e}",
|
||||||
return await worker_run_data.worker.async_generate(params)
|
error_code=0,
|
||||||
else:
|
|
||||||
return await self.run_blocking_func(
|
|
||||||
worker_run_data.worker.generate, params
|
|
||||||
)
|
)
|
||||||
|
async with worker_run_data.semaphore:
|
||||||
|
if worker_run_data.worker.support_async():
|
||||||
|
return await worker_run_data.worker.async_generate(params)
|
||||||
|
else:
|
||||||
|
return await self.run_blocking_func(
|
||||||
|
worker_run_data.worker.generate, params
|
||||||
|
)
|
||||||
|
|
||||||
async def embeddings(self, params: Dict) -> List[List[float]]:
|
async def embeddings(self, params: Dict) -> List[List[float]]:
|
||||||
"""Embed input"""
|
"""Embed input"""
|
||||||
try:
|
with root_tracer.start_span(
|
||||||
worker_run_data = await self._get_model(params, worker_type="text2vec")
|
"WorkerManager.embeddings", params.get("span_id")
|
||||||
except Exception as e:
|
) as span:
|
||||||
raise e
|
params["span_id"] = span.span_id
|
||||||
async with worker_run_data.semaphore:
|
try:
|
||||||
if worker_run_data.worker.support_async():
|
worker_run_data = await self._get_model(params, worker_type="text2vec")
|
||||||
return await worker_run_data.worker.async_embeddings(params)
|
except Exception as e:
|
||||||
else:
|
raise e
|
||||||
return await self.run_blocking_func(
|
async with worker_run_data.semaphore:
|
||||||
worker_run_data.worker.embeddings, params
|
if worker_run_data.worker.support_async():
|
||||||
)
|
return await worker_run_data.worker.async_embeddings(params)
|
||||||
|
else:
|
||||||
|
return await self.run_blocking_func(
|
||||||
|
worker_run_data.worker.embeddings, params
|
||||||
|
)
|
||||||
|
|
||||||
def sync_embeddings(self, params: Dict) -> List[List[float]]:
|
def sync_embeddings(self, params: Dict) -> List[List[float]]:
|
||||||
worker_run_data = self._sync_get_model(params, worker_type="text2vec")
|
worker_run_data = self._sync_get_model(params, worker_type="text2vec")
|
||||||
@ -608,6 +623,9 @@ async def generate_json_stream(params):
|
|||||||
@router.post("/worker/generate_stream")
|
@router.post("/worker/generate_stream")
|
||||||
async def api_generate_stream(request: PromptRequest):
|
async def api_generate_stream(request: PromptRequest):
|
||||||
params = request.dict(exclude_none=True)
|
params = request.dict(exclude_none=True)
|
||||||
|
span_id = root_tracer.get_current_span_id()
|
||||||
|
if "span_id" not in params and span_id:
|
||||||
|
params["span_id"] = span_id
|
||||||
generator = generate_json_stream(params)
|
generator = generate_json_stream(params)
|
||||||
return StreamingResponse(generator)
|
return StreamingResponse(generator)
|
||||||
|
|
||||||
@ -615,12 +633,18 @@ async def api_generate_stream(request: PromptRequest):
|
|||||||
@router.post("/worker/generate")
|
@router.post("/worker/generate")
|
||||||
async def api_generate(request: PromptRequest):
|
async def api_generate(request: PromptRequest):
|
||||||
params = request.dict(exclude_none=True)
|
params = request.dict(exclude_none=True)
|
||||||
|
span_id = root_tracer.get_current_span_id()
|
||||||
|
if "span_id" not in params and span_id:
|
||||||
|
params["span_id"] = span_id
|
||||||
return await worker_manager.generate(params)
|
return await worker_manager.generate(params)
|
||||||
|
|
||||||
|
|
||||||
@router.post("/worker/embeddings")
|
@router.post("/worker/embeddings")
|
||||||
async def api_embeddings(request: EmbeddingsRequest):
|
async def api_embeddings(request: EmbeddingsRequest):
|
||||||
params = request.dict(exclude_none=True)
|
params = request.dict(exclude_none=True)
|
||||||
|
span_id = root_tracer.get_current_span_id()
|
||||||
|
if "span_id" not in params and span_id:
|
||||||
|
params["span_id"] = span_id
|
||||||
return await worker_manager.embeddings(params)
|
return await worker_manager.embeddings(params)
|
||||||
|
|
||||||
|
|
||||||
@ -801,10 +825,18 @@ def _build_worker(worker_params: ModelWorkerParameters):
|
|||||||
def _start_local_worker(
|
def _start_local_worker(
|
||||||
worker_manager: WorkerManagerAdapter, worker_params: ModelWorkerParameters
|
worker_manager: WorkerManagerAdapter, worker_params: ModelWorkerParameters
|
||||||
):
|
):
|
||||||
worker = _build_worker(worker_params)
|
with root_tracer.start_span(
|
||||||
if not worker_manager.worker_manager:
|
"WorkerManager._start_local_worker",
|
||||||
worker_manager.worker_manager = _create_local_model_manager(worker_params)
|
span_type=SpanType.RUN,
|
||||||
worker_manager.worker_manager.add_worker(worker, worker_params)
|
metadata={
|
||||||
|
"run_service": SpanTypeRunName.WORKER_MANAGER,
|
||||||
|
"params": _get_dict_from_obj(worker_params),
|
||||||
|
},
|
||||||
|
):
|
||||||
|
worker = _build_worker(worker_params)
|
||||||
|
if not worker_manager.worker_manager:
|
||||||
|
worker_manager.worker_manager = _create_local_model_manager(worker_params)
|
||||||
|
worker_manager.worker_manager.add_worker(worker, worker_params)
|
||||||
|
|
||||||
|
|
||||||
def _start_local_embedding_worker(
|
def _start_local_embedding_worker(
|
||||||
@ -928,17 +960,17 @@ def run_worker_manager(
|
|||||||
# Run worker manager independently
|
# Run worker manager independently
|
||||||
embedded_mod = False
|
embedded_mod = False
|
||||||
app = _setup_fastapi(worker_params)
|
app = _setup_fastapi(worker_params)
|
||||||
_start_local_worker(worker_manager, worker_params)
|
|
||||||
_start_local_embedding_worker(
|
system_app = SystemApp(app)
|
||||||
worker_manager, embedding_model_name, embedding_model_path
|
initialize_tracer(
|
||||||
)
|
system_app,
|
||||||
else:
|
os.path.join(LOGDIR, "dbgpt_model_worker_manager_tracer.jsonl"),
|
||||||
_start_local_worker(worker_manager, worker_params)
|
root_operation_name="DB-GPT-WorkerManager-Entry",
|
||||||
_start_local_embedding_worker(
|
)
|
||||||
worker_manager, embedding_model_name, embedding_model_path
|
_start_local_worker(worker_manager, worker_params)
|
||||||
)
|
_start_local_embedding_worker(
|
||||||
loop = asyncio.get_event_loop()
|
worker_manager, embedding_model_name, embedding_model_path
|
||||||
loop.run_until_complete(worker_manager.start())
|
)
|
||||||
|
|
||||||
if include_router:
|
if include_router:
|
||||||
app.include_router(router, prefix="/api")
|
app.include_router(router, prefix="/api")
|
||||||
@ -946,6 +978,8 @@ def run_worker_manager(
|
|||||||
if not embedded_mod:
|
if not embedded_mod:
|
||||||
import uvicorn
|
import uvicorn
|
||||||
|
|
||||||
|
loop = asyncio.get_event_loop()
|
||||||
|
loop.run_until_complete(worker_manager.start())
|
||||||
uvicorn.run(
|
uvicorn.run(
|
||||||
app, host=worker_params.host, port=worker_params.port, log_level="info"
|
app, host=worker_params.host, port=worker_params.port, log_level="info"
|
||||||
)
|
)
|
||||||
|
@ -46,7 +46,7 @@ from pilot.summary.db_summary_client import DBSummaryClient
|
|||||||
|
|
||||||
from pilot.model.cluster import BaseModelController, WorkerManager, WorkerManagerFactory
|
from pilot.model.cluster import BaseModelController, WorkerManager, WorkerManagerFactory
|
||||||
from pilot.model.base import FlatSupportedModel
|
from pilot.model.base import FlatSupportedModel
|
||||||
from pilot.utils.tracer import root_tracer
|
from pilot.utils.tracer import root_tracer, SpanType
|
||||||
|
|
||||||
router = APIRouter()
|
router = APIRouter()
|
||||||
CFG = Config()
|
CFG = Config()
|
||||||
@ -367,7 +367,9 @@ async def chat_completions(dialogue: ConversationVo = Body()):
|
|||||||
print(
|
print(
|
||||||
f"chat_completions:{dialogue.chat_mode},{dialogue.select_param},{dialogue.model_name}"
|
f"chat_completions:{dialogue.chat_mode},{dialogue.select_param},{dialogue.model_name}"
|
||||||
)
|
)
|
||||||
with root_tracer.start_span("chat_completions", metadata=dialogue.dict()) as _:
|
with root_tracer.start_span(
|
||||||
|
"get_chat_instance", span_type=SpanType.CHAT, metadata=dialogue.dict()
|
||||||
|
):
|
||||||
chat: BaseChat = get_chat_instance(dialogue)
|
chat: BaseChat = get_chat_instance(dialogue)
|
||||||
# background_tasks = BackgroundTasks()
|
# background_tasks = BackgroundTasks()
|
||||||
# background_tasks.add_task(release_model_semaphore)
|
# background_tasks.add_task(release_model_semaphore)
|
||||||
@ -419,11 +421,10 @@ async def model_supports(worker_manager: WorkerManager = Depends(get_worker_mana
|
|||||||
|
|
||||||
|
|
||||||
async def no_stream_generator(chat):
|
async def no_stream_generator(chat):
|
||||||
span = root_tracer.start_span("no_stream_generator")
|
with root_tracer.start_span("no_stream_generator"):
|
||||||
msg = await chat.nostream_call()
|
msg = await chat.nostream_call()
|
||||||
msg = msg.replace("\n", "\\n")
|
msg = msg.replace("\n", "\\n")
|
||||||
yield f"data: {msg}\n\n"
|
yield f"data: {msg}\n\n"
|
||||||
span.end()
|
|
||||||
|
|
||||||
|
|
||||||
async def stream_generator(chat, incremental: bool, model_name: str):
|
async def stream_generator(chat, incremental: bool, model_name: str):
|
||||||
@ -467,9 +468,9 @@ async def stream_generator(chat, incremental: bool, model_name: str):
|
|||||||
yield f"data:{msg}\n\n"
|
yield f"data:{msg}\n\n"
|
||||||
previous_response = msg
|
previous_response = msg
|
||||||
await asyncio.sleep(0.02)
|
await asyncio.sleep(0.02)
|
||||||
span.end()
|
|
||||||
if incremental:
|
if incremental:
|
||||||
yield "data: [DONE]\n\n"
|
yield "data: [DONE]\n\n"
|
||||||
|
span.end()
|
||||||
chat.current_message.add_ai_message(msg)
|
chat.current_message.add_ai_message(msg)
|
||||||
chat.current_message.add_view_message(msg)
|
chat.current_message.add_view_message(msg)
|
||||||
chat.memory.append(chat.current_message)
|
chat.memory.append(chat.current_message)
|
||||||
|
@ -139,7 +139,9 @@ class BaseChat(ABC):
|
|||||||
def _get_span_metadata(self, payload: Dict) -> Dict:
|
def _get_span_metadata(self, payload: Dict) -> Dict:
|
||||||
metadata = {k: v for k, v in payload.items()}
|
metadata = {k: v for k, v in payload.items()}
|
||||||
del metadata["prompt"]
|
del metadata["prompt"]
|
||||||
metadata["messages"] = list(map(lambda m: m.dict(), metadata["messages"]))
|
metadata["messages"] = list(
|
||||||
|
map(lambda m: m if isinstance(m, dict) else m.dict(), metadata["messages"])
|
||||||
|
)
|
||||||
return metadata
|
return metadata
|
||||||
|
|
||||||
async def stream_call(self):
|
async def stream_call(self):
|
||||||
@ -152,6 +154,7 @@ class BaseChat(ABC):
|
|||||||
span = root_tracer.start_span(
|
span = root_tracer.start_span(
|
||||||
"BaseChat.stream_call", metadata=self._get_span_metadata(payload)
|
"BaseChat.stream_call", metadata=self._get_span_metadata(payload)
|
||||||
)
|
)
|
||||||
|
payload["span_id"] = span.span_id
|
||||||
try:
|
try:
|
||||||
from pilot.model.cluster import WorkerManagerFactory
|
from pilot.model.cluster import WorkerManagerFactory
|
||||||
|
|
||||||
@ -178,6 +181,7 @@ class BaseChat(ABC):
|
|||||||
span = root_tracer.start_span(
|
span = root_tracer.start_span(
|
||||||
"BaseChat.nostream_call", metadata=self._get_span_metadata(payload)
|
"BaseChat.nostream_call", metadata=self._get_span_metadata(payload)
|
||||||
)
|
)
|
||||||
|
payload["span_id"] = span.span_id
|
||||||
try:
|
try:
|
||||||
from pilot.model.cluster import WorkerManagerFactory
|
from pilot.model.cluster import WorkerManagerFactory
|
||||||
|
|
||||||
@ -185,7 +189,7 @@ class BaseChat(ABC):
|
|||||||
ComponentType.WORKER_MANAGER_FACTORY, WorkerManagerFactory
|
ComponentType.WORKER_MANAGER_FACTORY, WorkerManagerFactory
|
||||||
).create()
|
).create()
|
||||||
|
|
||||||
with root_tracer.start_span("BaseChat.invoke_worker_manager.generate") as _:
|
with root_tracer.start_span("BaseChat.invoke_worker_manager.generate"):
|
||||||
model_output = await worker_manager.generate(payload)
|
model_output = await worker_manager.generate(payload)
|
||||||
|
|
||||||
### output parse
|
### output parse
|
||||||
@ -206,7 +210,7 @@ class BaseChat(ABC):
|
|||||||
"ai_response_text": ai_response_text,
|
"ai_response_text": ai_response_text,
|
||||||
"prompt_define_response": prompt_define_response,
|
"prompt_define_response": prompt_define_response,
|
||||||
}
|
}
|
||||||
with root_tracer.start_span("BaseChat.do_action", metadata=metadata) as _:
|
with root_tracer.start_span("BaseChat.do_action", metadata=metadata):
|
||||||
### run
|
### run
|
||||||
result = self.do_action(prompt_define_response)
|
result = self.do_action(prompt_define_response)
|
||||||
|
|
||||||
|
@ -119,6 +119,14 @@ except ImportError as e:
|
|||||||
logging.warning(f"Integrating dbgpt knowledge command line tool failed: {e}")
|
logging.warning(f"Integrating dbgpt knowledge command line tool failed: {e}")
|
||||||
|
|
||||||
|
|
||||||
|
try:
|
||||||
|
from pilot.utils.tracer.tracer_cli import trace_cli_group
|
||||||
|
|
||||||
|
add_command_alias(trace_cli_group, name="trace", parent_group=cli)
|
||||||
|
except ImportError as e:
|
||||||
|
logging.warning(f"Integrating dbgpt trace command line tool failed: {e}")
|
||||||
|
|
||||||
|
|
||||||
def main():
|
def main():
|
||||||
return cli()
|
return cli()
|
||||||
|
|
||||||
|
@ -8,8 +8,6 @@ from pilot.component import ComponentType, SystemApp
|
|||||||
from pilot.utils.executor_utils import DefaultExecutorFactory
|
from pilot.utils.executor_utils import DefaultExecutorFactory
|
||||||
from pilot.embedding_engine.embedding_factory import EmbeddingFactory
|
from pilot.embedding_engine.embedding_factory import EmbeddingFactory
|
||||||
from pilot.server.base import WebWerverParameters
|
from pilot.server.base import WebWerverParameters
|
||||||
from pilot.configs.model_config import LOGDIR
|
|
||||||
from pilot.utils.tracer import root_tracer, initialize_tracer
|
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from langchain.embeddings.base import Embeddings
|
from langchain.embeddings.base import Embeddings
|
||||||
@ -26,8 +24,6 @@ def initialize_components(
|
|||||||
):
|
):
|
||||||
from pilot.model.cluster.controller.controller import controller
|
from pilot.model.cluster.controller.controller import controller
|
||||||
|
|
||||||
initialize_tracer(system_app, os.path.join(LOGDIR, "dbgpt_webserver_tracer.jsonl"))
|
|
||||||
|
|
||||||
# Register global default executor factory first
|
# Register global default executor factory first
|
||||||
system_app.register(DefaultExecutorFactory)
|
system_app.register(DefaultExecutorFactory)
|
||||||
|
|
||||||
|
@ -7,7 +7,7 @@ ROOT_PATH = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__fi
|
|||||||
sys.path.append(ROOT_PATH)
|
sys.path.append(ROOT_PATH)
|
||||||
import signal
|
import signal
|
||||||
from pilot.configs.config import Config
|
from pilot.configs.config import Config
|
||||||
from pilot.configs.model_config import LLM_MODEL_CONFIG, EMBEDDING_MODEL_CONFIG
|
from pilot.configs.model_config import LLM_MODEL_CONFIG, EMBEDDING_MODEL_CONFIG, LOGDIR
|
||||||
from pilot.component import SystemApp
|
from pilot.component import SystemApp
|
||||||
|
|
||||||
from pilot.server.base import (
|
from pilot.server.base import (
|
||||||
@ -38,6 +38,8 @@ from pilot.utils.utils import (
|
|||||||
_get_logging_level,
|
_get_logging_level,
|
||||||
logging_str_to_uvicorn_level,
|
logging_str_to_uvicorn_level,
|
||||||
)
|
)
|
||||||
|
from pilot.utils.tracer import root_tracer, initialize_tracer, SpanType, SpanTypeRunName
|
||||||
|
from pilot.utils.parameter_utils import _get_dict_from_obj
|
||||||
|
|
||||||
static_file_path = os.path.join(os.getcwd(), "server/static")
|
static_file_path = os.path.join(os.getcwd(), "server/static")
|
||||||
|
|
||||||
@ -98,17 +100,21 @@ def mount_static_files(app):
|
|||||||
app.add_exception_handler(RequestValidationError, validation_exception_handler)
|
app.add_exception_handler(RequestValidationError, validation_exception_handler)
|
||||||
|
|
||||||
|
|
||||||
|
def _get_webserver_params(args: List[str] = None):
|
||||||
|
from pilot.utils.parameter_utils import EnvArgumentParser
|
||||||
|
|
||||||
|
parser: argparse.ArgumentParser = EnvArgumentParser.create_argparse_option(
|
||||||
|
WebWerverParameters
|
||||||
|
)
|
||||||
|
return WebWerverParameters(**vars(parser.parse_args(args=args)))
|
||||||
|
|
||||||
|
|
||||||
def initialize_app(param: WebWerverParameters = None, args: List[str] = None):
|
def initialize_app(param: WebWerverParameters = None, args: List[str] = None):
|
||||||
"""Initialize app
|
"""Initialize app
|
||||||
If you use gunicorn as a process manager, initialize_app can be invoke in `on_starting` hook.
|
If you use gunicorn as a process manager, initialize_app can be invoke in `on_starting` hook.
|
||||||
"""
|
"""
|
||||||
if not param:
|
if not param:
|
||||||
from pilot.utils.parameter_utils import EnvArgumentParser
|
param = _get_webserver_params(args)
|
||||||
|
|
||||||
parser: argparse.ArgumentParser = EnvArgumentParser.create_argparse_option(
|
|
||||||
WebWerverParameters
|
|
||||||
)
|
|
||||||
param = WebWerverParameters(**vars(parser.parse_args(args=args)))
|
|
||||||
|
|
||||||
if not param.log_level:
|
if not param.log_level:
|
||||||
param.log_level = _get_logging_level()
|
param.log_level = _get_logging_level()
|
||||||
@ -127,7 +133,7 @@ def initialize_app(param: WebWerverParameters = None, args: List[str] = None):
|
|||||||
model_start_listener = _create_model_start_listener(system_app)
|
model_start_listener = _create_model_start_listener(system_app)
|
||||||
initialize_components(param, system_app, embedding_model_name, embedding_model_path)
|
initialize_components(param, system_app, embedding_model_name, embedding_model_path)
|
||||||
|
|
||||||
model_path = LLM_MODEL_CONFIG[CFG.LLM_MODEL]
|
model_path = LLM_MODEL_CONFIG.get(CFG.LLM_MODEL)
|
||||||
if not param.light:
|
if not param.light:
|
||||||
print("Model Unified Deployment Mode!")
|
print("Model Unified Deployment Mode!")
|
||||||
if not param.remote_embedding:
|
if not param.remote_embedding:
|
||||||
@ -174,8 +180,20 @@ def run_uvicorn(param: WebWerverParameters):
|
|||||||
|
|
||||||
|
|
||||||
def run_webserver(param: WebWerverParameters = None):
|
def run_webserver(param: WebWerverParameters = None):
|
||||||
param = initialize_app(param)
|
if not param:
|
||||||
run_uvicorn(param)
|
param = _get_webserver_params()
|
||||||
|
initialize_tracer(system_app, os.path.join(LOGDIR, "dbgpt_webserver_tracer.jsonl"))
|
||||||
|
|
||||||
|
with root_tracer.start_span(
|
||||||
|
"run_webserver",
|
||||||
|
span_type=SpanType.RUN,
|
||||||
|
metadata={
|
||||||
|
"run_service": SpanTypeRunName.WEBSERVER,
|
||||||
|
"params": _get_dict_from_obj(param),
|
||||||
|
},
|
||||||
|
):
|
||||||
|
param = initialize_app(param)
|
||||||
|
run_uvicorn(param)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
@ -13,7 +13,7 @@ from pilot.model.cluster import run_worker_manager
|
|||||||
|
|
||||||
CFG = Config()
|
CFG = Config()
|
||||||
|
|
||||||
model_path = LLM_MODEL_CONFIG[CFG.LLM_MODEL]
|
model_path = LLM_MODEL_CONFIG.get(CFG.LLM_MODEL)
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
run_worker_manager(
|
run_worker_manager(
|
||||||
|
@ -1,6 +1,6 @@
|
|||||||
import argparse
|
import argparse
|
||||||
import os
|
import os
|
||||||
from dataclasses import dataclass, fields, MISSING, asdict, field
|
from dataclasses import dataclass, fields, MISSING, asdict, field, is_dataclass
|
||||||
from typing import Any, List, Optional, Type, Union, Callable, Dict
|
from typing import Any, List, Optional, Type, Union, Callable, Dict
|
||||||
from collections import OrderedDict
|
from collections import OrderedDict
|
||||||
|
|
||||||
@ -590,6 +590,20 @@ def _extract_parameter_details(
|
|||||||
return descriptions
|
return descriptions
|
||||||
|
|
||||||
|
|
||||||
|
def _get_dict_from_obj(obj, default_value=None) -> Optional[Dict]:
|
||||||
|
if not obj:
|
||||||
|
return None
|
||||||
|
if is_dataclass(type(obj)):
|
||||||
|
params = {}
|
||||||
|
for field_info in fields(obj):
|
||||||
|
value = _get_simple_privacy_field_value(obj, field_info)
|
||||||
|
params[field_info.name] = value
|
||||||
|
return params
|
||||||
|
if isinstance(obj, dict):
|
||||||
|
return obj
|
||||||
|
return default_value
|
||||||
|
|
||||||
|
|
||||||
class _SimpleArgParser:
|
class _SimpleArgParser:
|
||||||
def __init__(self, *args):
|
def __init__(self, *args):
|
||||||
self.params = {arg.replace("_", "-"): None for arg in args}
|
self.params = {arg.replace("_", "-"): None for arg in args}
|
||||||
|
@ -1,5 +1,7 @@
|
|||||||
from pilot.utils.tracer.base import (
|
from pilot.utils.tracer.base import (
|
||||||
|
SpanType,
|
||||||
Span,
|
Span,
|
||||||
|
SpanTypeRunName,
|
||||||
Tracer,
|
Tracer,
|
||||||
SpanStorage,
|
SpanStorage,
|
||||||
SpanStorageType,
|
SpanStorageType,
|
||||||
@ -14,7 +16,9 @@ from pilot.utils.tracer.tracer_impl import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
|
"SpanType",
|
||||||
"Span",
|
"Span",
|
||||||
|
"SpanTypeRunName",
|
||||||
"Tracer",
|
"Tracer",
|
||||||
"SpanStorage",
|
"SpanStorage",
|
||||||
"SpanStorageType",
|
"SpanStorageType",
|
||||||
|
@ -10,22 +10,41 @@ from datetime import datetime
|
|||||||
from pilot.component import BaseComponent, SystemApp, ComponentType
|
from pilot.component import BaseComponent, SystemApp, ComponentType
|
||||||
|
|
||||||
|
|
||||||
|
class SpanType(str, Enum):
    """Category of a trace span.

    Members are also ``str`` instances, so they compare equal to their raw
    string values.
    """

    # Generic span category.
    BASE = "base"
    # Span describing the start-up ("run") of a service.
    RUN = "run"
    # Span belonging to a chat conversation.
    CHAT = "chat"
|
||||||
|
|
||||||
|
|
||||||
|
class SpanTypeRunName(str, Enum):
    """Service names recorded under ``run_service`` in RUN span metadata.

    Members are also ``str`` instances, so they compare equal to their raw
    string values.
    """

    WEBSERVER = "Webserver"
    WORKER_MANAGER = "WorkerManager"
    MODEL_WORKER = "ModelWorker"
    EMBEDDING_MODEL = "EmbeddingModel"

    @staticmethod
    def values():
        """Return the string value of every member, in definition order."""
        collected = []
        for member in SpanTypeRunName:
            collected.append(member.value)
        return collected
|
||||||
|
|
||||||
|
|
||||||
class Span:
|
class Span:
|
||||||
"""Represents a unit of work that is being traced.
|
"""Represents a unit of work that is being traced.
|
||||||
This can be any operation like a function call or a database query.
|
This can be any operation like a function call or a database query.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
span_type: str = "base"
|
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
trace_id: str,
|
trace_id: str,
|
||||||
span_id: str,
|
span_id: str,
|
||||||
|
span_type: SpanType = None,
|
||||||
parent_span_id: str = None,
|
parent_span_id: str = None,
|
||||||
operation_name: str = None,
|
operation_name: str = None,
|
||||||
metadata: Dict = None,
|
metadata: Dict = None,
|
||||||
end_caller: Callable[[Span], None] = None,
|
end_caller: Callable[[Span], None] = None,
|
||||||
):
|
):
|
||||||
|
if not span_type:
|
||||||
|
span_type = SpanType.BASE
|
||||||
|
self.span_type = span_type
|
||||||
# The unique identifier for the entire trace
|
# The unique identifier for the entire trace
|
||||||
self.trace_id = trace_id
|
self.trace_id = trace_id
|
||||||
# Unique identifier for this span within the trace
|
# Unique identifier for this span within the trace
|
||||||
@ -65,7 +84,7 @@ class Span:
|
|||||||
|
|
||||||
def to_dict(self) -> Dict:
|
def to_dict(self) -> Dict:
|
||||||
return {
|
return {
|
||||||
"span_type": self.span_type,
|
"span_type": self.span_type.value,
|
||||||
"trace_id": self.trace_id,
|
"trace_id": self.trace_id,
|
||||||
"span_id": self.span_id,
|
"span_id": self.span_id,
|
||||||
"parent_span_id": self.parent_span_id,
|
"parent_span_id": self.parent_span_id,
|
||||||
@ -124,7 +143,11 @@ class Tracer(BaseComponent, ABC):
|
|||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def start_span(
|
def start_span(
|
||||||
self, operation_name: str, parent_span_id: str = None, metadata: Dict = None
|
self,
|
||||||
|
operation_name: str,
|
||||||
|
parent_span_id: str = None,
|
||||||
|
span_type: SpanType = None,
|
||||||
|
metadata: Dict = None,
|
||||||
) -> Span:
|
) -> Span:
|
||||||
"""Begin a new span for the given operation. If provided, the span will be
|
"""Begin a new span for the given operation. If provided, the span will be
|
||||||
a child of the span with the given parent_span_id.
|
a child of the span with the given parent_span_id.
|
||||||
@ -158,4 +181,4 @@ class Tracer(BaseComponent, ABC):
|
|||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class TracerContext:
|
class TracerContext:
|
||||||
span_id: str
|
span_id: Optional[str] = None
|
||||||
|
540
pilot/utils/tracer/tracer_cli.py
Normal file
540
pilot/utils/tracer/tracer_cli.py
Normal file
@ -0,0 +1,540 @@
|
|||||||
|
import os
|
||||||
|
import click
|
||||||
|
import logging
|
||||||
|
import glob
|
||||||
|
import json
|
||||||
|
from datetime import datetime
|
||||||
|
from typing import Iterable, Dict, Callable
|
||||||
|
from pilot.configs.model_config import LOGDIR
|
||||||
|
from pilot.utils.tracer import SpanType, SpanTypeRunName
|
||||||
|
|
||||||
|
logger = logging.getLogger("dbgpt_cli")
|
||||||
|
|
||||||
|
|
||||||
|
_DEFAULT_FILE_PATTERN = os.path.join(LOGDIR, "dbgpt*.jsonl")
|
||||||
|
|
||||||
|
|
||||||
|
# Root of the ``trace`` sub-commands. The docstring below doubles as the
# group's help text, so it is kept verbatim; the body is intentionally empty.
@click.group("trace")
def trace_cli_group():
    """Analyze and visualize trace spans."""
|
||||||
|
|
||||||
|
|
||||||
|
@trace_cli_group.command()
@click.option(
    "--trace_id",
    required=False,
    type=str,
    default=None,
    show_default=True,
    help="Specify the trace ID to list",
)
@click.option(
    "--span_id",
    required=False,
    type=str,
    default=None,
    show_default=True,
    help="Specify the Span ID to list.",
)
@click.option(
    "--span_type",
    required=False,
    type=str,
    default=None,
    show_default=True,
    help="Specify the Span Type to list.",
)
@click.option(
    "--parent_span_id",
    required=False,
    type=str,
    default=None,
    show_default=True,
    help="Specify the Parent Span ID to list.",
)
@click.option(
    "--search",
    required=False,
    type=str,
    default=None,
    show_default=True,
    help="Search trace_id, span_id, parent_span_id, operation_name or content in metadata.",
)
@click.option(
    "-l",
    "--limit",
    type=int,
    default=20,
    help="Limit the number of recent span displayed.",
)
@click.option(
    "--start_time",
    type=str,
    help='Filter by start time. Format: "YYYY-MM-DD HH:MM:SS.mmm"',
)
@click.option(
    "--end_time", type=str, help='Filter by end time. Format: "YYYY-MM-DD HH:MM:SS.mmm"'
)
@click.option(
    "--desc",
    required=False,
    type=bool,
    default=False,
    is_flag=True,
    help="Whether to use reverse sorting. By default, sorting is based on start time.",
)
@click.option(
    "--output",
    required=False,
    type=click.Choice(["text", "html", "csv", "latex", "json"]),
    default="text",
    help="The output format",
)
@click.argument("files", nargs=-1, type=click.Path(exists=True, readable=True))
def list(
    trace_id: str,
    span_id: str,
    span_type: str,
    parent_span_id: str,
    search: str,
    limit: int,
    start_time: str,
    end_time: str,
    desc: bool,
    output: str,
    files=None,
):
    """List your trace spans"""
    # NOTE: the docstring above is the command's help text -- keep it verbatim.
    # NOTE: this command shadows the builtin ``list`` for the whole module, so
    # nothing in this file may call ``list(...)`` expecting the builtin.
    from prettytable import PrettyTable

    # Collect one predicate per requested filter. They are evaluated in this
    # order, lazily, for every span read from the files.
    checks = []
    if trace_id:
        checks.append(lambda s: s["trace_id"] == trace_id)
    if span_id:
        checks.append(lambda s: s["span_id"] == span_id)
    if span_type:
        checks.append(lambda s: s["span_type"] == span_type)
    if parent_span_id:
        checks.append(lambda s: s["parent_span_id"] == parent_span_id)
    if start_time:
        start_dt = _parse_datetime(start_time)
        checks.append(lambda s: _parse_datetime(s["start_time"]) >= start_dt)
    if end_time:
        # The upper bound is compared against the span's *start* time as well.
        end_dt = _parse_datetime(end_time)
        checks.append(lambda s: _parse_datetime(s["start_time"]) <= end_dt)
    if search:
        checks.append(_new_search_span_func(search))

    # With no files given, read_spans_from_files falls back to its default
    # file pattern.
    matching = (
        candidate
        for candidate in read_spans_from_files(files)
        if all(check(candidate) for check in checks)
    )

    # Order by start time (descending when --desc), then keep the first
    # ``limit`` spans of that ordering.
    ordered = sorted(
        matching,
        key=lambda s: _parse_datetime(s["start_time"]),
        reverse=desc,
    )
    shown = ordered[:limit]

    table = PrettyTable(
        ["Trace ID", "Span ID", "Operation Name", "Conversation UID"],
    )
    for row_span in shown:
        meta = row_span.get("metadata")
        # The conversation uid is only available when metadata is a dict.
        conv_uid = meta.get("conv_uid") if isinstance(meta, dict) else None
        table.add_row(
            [
                row_span.get("trace_id"),
                row_span.get("span_id"),
                row_span.get("operation_name"),
                conv_uid,
            ]
        )

    if output == "json":
        rendered = table.get_formatted_string(out_format=output, ensure_ascii=False)
    else:
        rendered = table.get_formatted_string(out_format=output)
    print(rendered)
|
||||||
|
|
||||||
|
|
||||||
|
@trace_cli_group.command()
@click.option(
    "--trace_id",
    required=True,
    type=str,
    help="Specify the trace ID to list",
)
@click.argument("files", nargs=-1, type=click.Path(exists=True, readable=True))
def tree(trace_id: str, files):
    """Display trace links as a tree"""
    # NOTE: the docstring above is the command's help text -- keep it verbatim.
    # Build the hierarchy for the requested trace and print it in one step.
    _print_trace_hierarchy(_view_trace_hierarchy(trace_id, files))
|
||||||
|
|
||||||
|
|
||||||
|
@trace_cli_group.command()
@click.option(
    "--trace_id",
    required=False,
    type=str,
    default=None,
    help="Specify the trace ID to analyze. If None, show latest conversation details",
)
@click.option(
    "--tree",
    required=False,
    type=bool,
    default=False,
    is_flag=True,
    help="Display trace spans as a tree",
)
@click.option(
    "--hide_run_params",
    required=False,
    type=bool,
    default=False,
    is_flag=True,
    help="Hide run params",
)
@click.option(
    "--output",
    required=False,
    type=click.Choice(["text", "html", "csv", "latex", "json"]),
    default="text",
    help="The output format",
)
@click.argument("files", nargs=-1, type=click.Path(exists=False, readable=True))
def chat(trace_id: str, tree: bool, hide_run_params: bool, output: str, files):
    """Show conversation details"""
    # NOTE: the docstring above is the command's help text -- keep it verbatim.
    #
    # Flow:
    #   1. Read all spans, newest first.
    #   2. Collect the RUN span of each service (its start-up parameters) and
    #      locate the conversation: the newest CHAT span, or the one matching
    #      ``trace_id`` when given.
    #   3. Print the run-parameter tables unless ``hide_run_params`` is set.
    #   4. Print the details of the selected conversation, optionally preceded
    #      by its span tree.
    from prettytable import PrettyTable

    # ``sorted`` already returns a fresh list. (The builtin ``list`` is
    # shadowed by the ``list`` command in this module, so it is never called.)
    spans = sorted(
        read_spans_from_files(files),
        key=lambda span: _parse_datetime(span["start_time"]),
        reverse=True,
    )

    service_spans = {}
    service_names = set(SpanTypeRunName.values())
    found_trace_id = None
    for sp in spans:
        span_type = sp["span_type"]
        metadata = sp.get("metadata")
        if span_type == SpanType.RUN:
            # A RUN span without usable metadata cannot be attributed to a
            # service; ignore it instead of failing on metadata["run_service"].
            if isinstance(metadata, dict) and "run_service" in metadata:
                service_spans[metadata["run_service"]] = sp.copy()
                if set(service_spans.keys()) == service_names and found_trace_id:
                    break
        elif span_type == SpanType.CHAT and not found_trace_id:
            if not trace_id:
                found_trace_id = sp["trace_id"]
            if trace_id and trace_id == sp["trace_id"]:
                found_trace_id = trace_id

    service_tables = {}
    out_kwargs = {"ensure_ascii": False} if output == "json" else {}
    for service_name, sp in service_spans.items():
        table = PrettyTable(["Config Key", "Config Value"], title=service_name)
        # "params" may be missing or None (e.g. when the service recorded no
        # parameters); render an empty table in that case.
        for k, v in (sp["metadata"].get("params") or {}).items():
            table.add_row([k, v])
        service_tables[service_name] = table

    if not hide_run_params:
        merged_table1 = merge_tables_horizontally(
            [
                service_tables.get(SpanTypeRunName.WEBSERVER.value),
                service_tables.get(SpanTypeRunName.EMBEDDING_MODEL.value),
            ]
        )
        merged_table2 = merge_tables_horizontally(
            [
                service_tables.get(SpanTypeRunName.MODEL_WORKER.value),
                service_tables.get(SpanTypeRunName.WORKER_MANAGER.value),
            ]
        )
        if output == "text":
            # merge_tables_horizontally returns None when neither service was
            # found; skip it rather than printing the literal "None".
            for merged in (merged_table1, merged_table2):
                if merged is not None:
                    print(merged)
        else:
            for table in service_tables.values():
                print(table.get_formatted_string(out_format=output, **out_kwargs))

    if not found_trace_id:
        print(f"Can't found conversation with trace_id: {trace_id}")
        return
    trace_id = found_trace_id

    # Restore chronological order (spans are newest-first) before building
    # the hierarchy.
    trace_spans = [span for span in spans if span["trace_id"] == trace_id][::-1]
    hierarchy = _build_trace_hierarchy(trace_spans)
    if tree:
        print("\nInvoke Trace Tree:\n")
        _print_trace_hierarchy(hierarchy)

    trace_spans = _get_ordered_trace_from(hierarchy)
    table = PrettyTable(["Key", "Value Value"], title="Chat Trace Details")
    # Long text is only wrapped to the terminal width for human-readable
    # output; machine formats keep the raw text.
    split_long_text = output == "text"

    def wrap(value):
        # Only strings can be wrapped; missing (None) or structured values are
        # shown as they are.
        if isinstance(value, str):
            return split_string_by_terminal_width(value, split=split_long_text)
        return value

    for sp in trace_spans:
        op = sp["operation_name"]
        # Spans may carry no metadata at all; treat that as empty so the
        # lookups below yield None instead of raising.
        metadata = sp.get("metadata") or {}
        # Entries with no end_time are the "start" record of a span; entries
        # with an end_time are its "end" record.
        is_start = not sp["end_time"]

        if op == "get_chat_instance" and is_start:
            table.add_row(["trace_id", trace_id])
            table.add_row(["span_id", sp["span_id"]])
            table.add_row(["conv_uid", metadata.get("conv_uid")])
            table.add_row(["user_input", metadata.get("user_input")])
            table.add_row(["chat_mode", metadata.get("chat_mode")])
            table.add_row(["select_param", metadata.get("select_param")])
            table.add_row(["model_name", metadata.get("model_name")])

        if op in ["BaseChat.stream_call", "BaseChat.nostream_call"]:
            if is_start:
                table.add_row(["temperature", metadata.get("temperature")])
                table.add_row(["max_new_tokens", metadata.get("max_new_tokens")])
                table.add_row(["echo", metadata.get("echo")])
            elif "error" in metadata:
                table.add_row(["BaseChat Error", metadata.get("error")])

        if op == "BaseChat.nostream_call" and is_start:
            if "model_output" in metadata:
                model_output = metadata.get("model_output")
                # Expected to be a dict with a "text" entry; anything else is
                # displayed unchanged.
                if isinstance(model_output, dict):
                    model_output = model_output.get("text")
                table.add_row(["BaseChat model_output", wrap(model_output)])
            if "ai_response_text" in metadata:
                table.add_row(
                    ["BaseChat ai_response_text", wrap(metadata.get("ai_response_text"))]
                )
            if "prompt_define_response" in metadata:
                table.add_row(
                    [
                        "BaseChat prompt_define_response",
                        wrap(metadata.get("prompt_define_response")),
                    ]
                )

        if op == "DefaultModelWorker_call.generate_stream_func":
            if is_start:
                table.add_row(["llm_adapter", metadata.get("llm_adapter")])
                table.add_row(["User prompt", wrap(metadata.get("prompt"))])
            else:
                # Wrapped under the same rule as every other long-text row.
                table.add_row(["Model output", wrap(metadata.get("output"))])

        if (
            op
            in [
                "DefaultModelWorker.async_generate_stream",
                "DefaultModelWorker.generate_stream",
            ]
            and "error" in metadata
        ):
            table.add_row(["Model Error", metadata.get("error")])

    print(table.get_formatted_string(out_format=output, **out_kwargs))
|
||||||
|
|
||||||
|
|
||||||
|
def read_spans_from_files(files=None) -> Iterable[Dict]:
    """Lazily yield the spans stored in one or more JSON-lines files.

    Args:
        files: Iterable of file paths or glob patterns. When empty or
            ``None``, the module default pattern ``_DEFAULT_FILE_PATTERN`` is
            used.

    Yields:
        One decoded JSON value per non-blank line, in file order.

    Blank lines are skipped. Lines that are not valid JSON (for example a
    line that is still being written by a running service) are reported with
    a warning and skipped, so a single damaged line does not abort the whole
    command.
    """
    if not files:
        files = [_DEFAULT_FILE_PATTERN]

    # Same logger name as this module's ``logger``.
    log = logging.getLogger("dbgpt_cli")

    for filepath in files:
        # Every entry is expanded as a glob pattern; a plain path matches
        # itself when it exists and nothing otherwise.
        for filename in glob.glob(filepath):
            with open(filename, "r") as file:
                for line_no, line in enumerate(file, start=1):
                    line = line.strip()
                    if not line:
                        continue
                    try:
                        span = json.loads(line)
                    except json.JSONDecodeError as e:
                        log.warning(
                            "Skip invalid span at %s:%s: %s", filename, line_no, e
                        )
                        continue
                    yield span
|
||||||
|
|
||||||
|
|
||||||
|
def _new_search_span_func(search: str):
|
||||||
|
def func(span: Dict) -> bool:
|
||||||
|
items = [span["trace_id"], span["span_id"], span["parent_span_id"]]
|
||||||
|
if "operation_name" in span:
|
||||||
|
items.append(span["operation_name"])
|
||||||
|
if "metadata" in span:
|
||||||
|
metadata = span["metadata"]
|
||||||
|
if isinstance(metadata, dict):
|
||||||
|
for k, v in metadata.items():
|
||||||
|
items.append(k)
|
||||||
|
items.append(v)
|
||||||
|
return any(search in str(item) for item in items if item)
|
||||||
|
|
||||||
|
return func
|
||||||
|
|
||||||
|
|
||||||
|
def _parse_datetime(dt_str):
|
||||||
|
"""Parse a datetime string to a datetime object."""
|
||||||
|
return datetime.strptime(dt_str, "%Y-%m-%d %H:%M:%S.%f")
|
||||||
|
|
||||||
|
|
||||||
|
def _build_trace_hierarchy(spans, parent_span_id=None, indent=0):
|
||||||
|
# Current spans
|
||||||
|
current_level_spans = [
|
||||||
|
span
|
||||||
|
for span in spans
|
||||||
|
if span["parent_span_id"] == parent_span_id and span["end_time"] is None
|
||||||
|
]
|
||||||
|
|
||||||
|
hierarchy = []
|
||||||
|
|
||||||
|
for start_span in current_level_spans:
|
||||||
|
# Find end span
|
||||||
|
end_span = next(
|
||||||
|
(
|
||||||
|
span
|
||||||
|
for span in spans
|
||||||
|
if span["span_id"] == start_span["span_id"]
|
||||||
|
and span["end_time"] is not None
|
||||||
|
),
|
||||||
|
None,
|
||||||
|
)
|
||||||
|
entry = {
|
||||||
|
"operation_name": start_span["operation_name"],
|
||||||
|
"parent_span_id": start_span["parent_span_id"],
|
||||||
|
"span_id": start_span["span_id"],
|
||||||
|
"start_time": start_span["start_time"],
|
||||||
|
"end_time": start_span["end_time"],
|
||||||
|
"metadata": start_span["metadata"],
|
||||||
|
"children": _build_trace_hierarchy(
|
||||||
|
spans, start_span["span_id"], indent + 1
|
||||||
|
),
|
||||||
|
}
|
||||||
|
hierarchy.append(entry)
|
||||||
|
|
||||||
|
# Append end span
|
||||||
|
if end_span:
|
||||||
|
entry_end = {
|
||||||
|
"operation_name": end_span["operation_name"],
|
||||||
|
"parent_span_id": end_span["parent_span_id"],
|
||||||
|
"span_id": end_span["span_id"],
|
||||||
|
"start_time": end_span["start_time"],
|
||||||
|
"end_time": end_span["end_time"],
|
||||||
|
"metadata": end_span["metadata"],
|
||||||
|
"children": [],
|
||||||
|
}
|
||||||
|
hierarchy.append(entry_end)
|
||||||
|
|
||||||
|
return hierarchy
|
||||||
|
|
||||||
|
|
||||||
|
def _view_trace_hierarchy(trace_id, files=None):
    """Build the span hierarchy of the trace identified by ``trace_id``.

    Spans are read from ``files`` via read_spans_from_files, restricted to
    those whose ``trace_id`` matches, and handed to _build_trace_hierarchy.
    Nothing is printed here; the hierarchy is returned.
    """
    selected = []
    for candidate in read_spans_from_files(files):
        if candidate["trace_id"] == trace_id:
            selected.append(candidate)
    return _build_trace_hierarchy(selected)
|
||||||
|
|
||||||
|
|
||||||
|
def _print_trace_hierarchy(hierarchy, indent=0):
|
||||||
|
"""Print link hierarchy"""
|
||||||
|
for entry in hierarchy:
|
||||||
|
print(
|
||||||
|
" " * indent
|
||||||
|
+ f"Operation: {entry['operation_name']} (Start: {entry['start_time']}, End: {entry['end_time']})"
|
||||||
|
)
|
||||||
|
_print_trace_hierarchy(entry["children"], indent + 1)
|
||||||
|
|
||||||
|
|
||||||
|
def _get_ordered_trace_from(hierarchy):
|
||||||
|
traces = []
|
||||||
|
|
||||||
|
def func(items):
|
||||||
|
for item in items:
|
||||||
|
traces.append(item)
|
||||||
|
func(item["children"])
|
||||||
|
|
||||||
|
func(hierarchy)
|
||||||
|
return traces
|
||||||
|
|
||||||
|
|
||||||
|
def _print(service_spans: Dict):
    """Placeholder: walks two pairs of service identifiers and does nothing.

    ``service_spans`` is currently unused and nothing is printed.
    """
    # NOTE(review): each pair mixes a member *name* (e.g. "WEBSERVER") with an
    # enum member; looks unfinished -- confirm the intent before building on it.
    first_pair = [SpanTypeRunName.WEBSERVER.name, SpanTypeRunName.EMBEDDING_MODEL]
    second_pair = [SpanTypeRunName.WORKER_MANAGER.name, SpanTypeRunName.MODEL_WORKER]
    for _pair in (first_pair, second_pair):
        continue
|
||||||
|
|
||||||
|
|
||||||
|
def merge_tables_horizontally(tables):
    """Place several PrettyTable objects side by side in one new table.

    Args:
        tables: Tables to merge; falsy entries (such as ``None``) are ignored.

    Returns:
        A new PrettyTable, or ``None`` when no usable table was given. Its
        columns are the columns of every input table in order, each renamed
        to ``"<name> (<table title>)"`` when the table has a title. Tables
        with fewer rows are padded with empty strings.
    """
    from prettytable import PrettyTable

    usable = [t for t in tables if t] if tables else []
    if not usable:
        return None

    # Row data is read from PrettyTable's ``_rows`` attribute.
    row_count = max(len(t._rows) for t in usable)

    headers = []
    for t in usable:
        for name in t.field_names:
            headers.append(f"{name} ({t.title})" if t.title else f"{name}")

    merged = PrettyTable()
    merged.field_names = headers

    for idx in range(row_count):
        combined = []
        for t in usable:
            if idx < len(t._rows):
                combined += t._rows[idx]
            else:
                # Pad the columns of tables that have run out of rows.
                combined += [""] * len(t.field_names)
        merged.add_row(combined)

    return merged
|
||||||
|
|
||||||
|
|
||||||
|
def split_string_by_terminal_width(s, split=True, max_len=None, sp="\n"):
    """
    Split a string into chunks of at most ``max_len`` characters.

    Parameters:
    - s: the input string. ``None`` is returned unchanged; any other
      non-string value is converted with ``str()`` before splitting.
    - split: when false, ``s`` is returned untouched.
    - max_len: maximum chunk length. When not given, 80% of the current
      terminal width is used, or 100 if the width cannot be determined.
    - sp: separator inserted between chunks.

    Returns the chunks joined with ``sp``.
    """
    if not split or s is None:
        return s
    if not isinstance(s, str):
        s = str(s)
    if not max_len:
        try:
            max_len = int(os.get_terminal_size().columns * 0.8)
        except OSError:
            # No terminal attached (e.g. output is piped): default to 100.
            max_len = 100
        if max_len < 1:
            # A reported width of 0 columns would give a zero step below.
            max_len = 100
    return sp.join([s[i : i + max_len] for i in range(0, len(s), max_len)])
|
@ -4,6 +4,7 @@ from functools import wraps
|
|||||||
|
|
||||||
from pilot.component import SystemApp, ComponentType
|
from pilot.component import SystemApp, ComponentType
|
||||||
from pilot.utils.tracer.base import (
|
from pilot.utils.tracer.base import (
|
||||||
|
SpanType,
|
||||||
Span,
|
Span,
|
||||||
Tracer,
|
Tracer,
|
||||||
SpanStorage,
|
SpanStorage,
|
||||||
@ -32,14 +33,23 @@ class DefaultTracer(Tracer):
|
|||||||
self._get_current_storage().append_span(span)
|
self._get_current_storage().append_span(span)
|
||||||
|
|
||||||
def start_span(
|
def start_span(
|
||||||
self, operation_name: str, parent_span_id: str = None, metadata: Dict = None
|
self,
|
||||||
|
operation_name: str,
|
||||||
|
parent_span_id: str = None,
|
||||||
|
span_type: SpanType = None,
|
||||||
|
metadata: Dict = None,
|
||||||
) -> Span:
|
) -> Span:
|
||||||
trace_id = (
|
trace_id = (
|
||||||
self._new_uuid() if parent_span_id is None else parent_span_id.split(":")[0]
|
self._new_uuid() if parent_span_id is None else parent_span_id.split(":")[0]
|
||||||
)
|
)
|
||||||
span_id = f"{trace_id}:{self._new_uuid()}"
|
span_id = f"{trace_id}:{self._new_uuid()}"
|
||||||
span = Span(
|
span = Span(
|
||||||
trace_id, span_id, parent_span_id, operation_name, metadata=metadata
|
trace_id,
|
||||||
|
span_id,
|
||||||
|
span_type,
|
||||||
|
parent_span_id,
|
||||||
|
operation_name,
|
||||||
|
metadata=metadata,
|
||||||
)
|
)
|
||||||
|
|
||||||
if self._span_storage_type in [
|
if self._span_storage_type in [
|
||||||
@ -81,11 +91,13 @@ class DefaultTracer(Tracer):
|
|||||||
|
|
||||||
|
|
||||||
class TracerManager:
|
class TracerManager:
|
||||||
|
"""The manager of current tracer"""
|
||||||
|
|
||||||
def __init__(self) -> None:
|
def __init__(self) -> None:
|
||||||
self._system_app: Optional[SystemApp] = None
|
self._system_app: Optional[SystemApp] = None
|
||||||
self._trace_context_var: ContextVar[TracerContext] = ContextVar(
|
self._trace_context_var: ContextVar[TracerContext] = ContextVar(
|
||||||
"trace_context",
|
"trace_context",
|
||||||
default=TracerContext(span_id="default_trace_id:default_span_id"),
|
default=TracerContext(),
|
||||||
)
|
)
|
||||||
|
|
||||||
def initialize(
|
def initialize(
|
||||||
@ -101,18 +113,23 @@ class TracerManager:
|
|||||||
return self._system_app.get_component(ComponentType.TRACER, Tracer, None)
|
return self._system_app.get_component(ComponentType.TRACER, Tracer, None)
|
||||||
|
|
||||||
def start_span(
|
def start_span(
|
||||||
self, operation_name: str, parent_span_id: str = None, metadata: Dict = None
|
self,
|
||||||
|
operation_name: str,
|
||||||
|
parent_span_id: str = None,
|
||||||
|
span_type: SpanType = None,
|
||||||
|
metadata: Dict = None,
|
||||||
) -> Span:
|
) -> Span:
|
||||||
|
"""Start a new span with operation_name
|
||||||
|
This method must not throw an exception under any case and try not to block as much as possible
|
||||||
|
"""
|
||||||
tracer = self._get_tracer()
|
tracer = self._get_tracer()
|
||||||
if not tracer:
|
if not tracer:
|
||||||
return Span("empty_span", "empty_span")
|
return Span("empty_span", "empty_span")
|
||||||
if not parent_span_id:
|
if not parent_span_id:
|
||||||
current_span = self.get_current_span()
|
parent_span_id = self.get_current_span_id()
|
||||||
if current_span:
|
return tracer.start_span(
|
||||||
parent_span_id = current_span.span_id
|
operation_name, parent_span_id, span_type=span_type, metadata=metadata
|
||||||
else:
|
)
|
||||||
parent_span_id = self._trace_context_var.get().span_id
|
|
||||||
return tracer.start_span(operation_name, parent_span_id, metadata)
|
|
||||||
|
|
||||||
def end_span(self, span: Span, **kwargs):
|
def end_span(self, span: Span, **kwargs):
|
||||||
tracer = self._get_tracer()
|
tracer = self._get_tracer()
|
||||||
@ -126,15 +143,22 @@ class TracerManager:
|
|||||||
return None
|
return None
|
||||||
return tracer.get_current_span()
|
return tracer.get_current_span()
|
||||||
|
|
||||||
|
def get_current_span_id(self) -> Optional[str]:
|
||||||
|
current_span = self.get_current_span()
|
||||||
|
if current_span:
|
||||||
|
return current_span.span_id
|
||||||
|
ctx = self._trace_context_var.get()
|
||||||
|
return ctx.span_id if ctx else None
|
||||||
|
|
||||||
|
|
||||||
root_tracer: TracerManager = TracerManager()
|
root_tracer: TracerManager = TracerManager()
|
||||||
|
|
||||||
|
|
||||||
def trace(operation_name: str):
|
def trace(operation_name: str, **trace_kwargs):
|
||||||
def decorator(func):
|
def decorator(func):
|
||||||
@wraps(func)
|
@wraps(func)
|
||||||
async def wrapper(*args, **kwargs):
|
async def wrapper(*args, **kwargs):
|
||||||
with root_tracer.start_span(operation_name) as span:
|
with root_tracer.start_span(operation_name, **trace_kwargs):
|
||||||
return await func(*args, **kwargs)
|
return await func(*args, **kwargs)
|
||||||
|
|
||||||
return wrapper
|
return wrapper
|
||||||
@ -142,14 +166,18 @@ def trace(operation_name: str):
|
|||||||
return decorator
|
return decorator
|
||||||
|
|
||||||
|
|
||||||
def initialize_tracer(system_app: SystemApp, tracer_filename: str):
|
def initialize_tracer(
|
||||||
|
system_app: SystemApp,
|
||||||
|
tracer_filename: str,
|
||||||
|
root_operation_name: str = "DB-GPT-Web-Entry",
|
||||||
|
):
|
||||||
if not system_app:
|
if not system_app:
|
||||||
return
|
return
|
||||||
from pilot.utils.tracer.span_storage import FileSpanStorage
|
from pilot.utils.tracer.span_storage import FileSpanStorage
|
||||||
|
|
||||||
trace_context_var = ContextVar(
|
trace_context_var = ContextVar(
|
||||||
"trace_context",
|
"trace_context",
|
||||||
default=TracerContext(span_id="default_trace_id:default_span_id"),
|
default=TracerContext(),
|
||||||
)
|
)
|
||||||
tracer = DefaultTracer(system_app)
|
tracer = DefaultTracer(system_app)
|
||||||
|
|
||||||
@ -160,5 +188,8 @@ def initialize_tracer(system_app: SystemApp, tracer_filename: str):
|
|||||||
from pilot.utils.tracer.tracer_middleware import TraceIDMiddleware
|
from pilot.utils.tracer.tracer_middleware import TraceIDMiddleware
|
||||||
|
|
||||||
system_app.app.add_middleware(
|
system_app.app.add_middleware(
|
||||||
TraceIDMiddleware, trace_context_var=trace_context_var, tracer=tracer
|
TraceIDMiddleware,
|
||||||
|
trace_context_var=trace_context_var,
|
||||||
|
tracer=tracer,
|
||||||
|
root_operation_name=root_operation_name,
|
||||||
)
|
)
|
||||||
|
@ -16,12 +16,14 @@ class TraceIDMiddleware(BaseHTTPMiddleware):
|
|||||||
app: ASGIApp,
|
app: ASGIApp,
|
||||||
trace_context_var: ContextVar[TracerContext],
|
trace_context_var: ContextVar[TracerContext],
|
||||||
tracer: Tracer,
|
tracer: Tracer,
|
||||||
|
root_operation_name: str = "DB-GPT-Web-Entry",
|
||||||
include_prefix: str = "/api",
|
include_prefix: str = "/api",
|
||||||
exclude_paths=_DEFAULT_EXCLUDE_PATHS,
|
exclude_paths=_DEFAULT_EXCLUDE_PATHS,
|
||||||
):
|
):
|
||||||
super().__init__(app)
|
super().__init__(app)
|
||||||
self.trace_context_var = trace_context_var
|
self.trace_context_var = trace_context_var
|
||||||
self.tracer = tracer
|
self.tracer = tracer
|
||||||
|
self.root_operation_name = root_operation_name
|
||||||
self.include_prefix = include_prefix
|
self.include_prefix = include_prefix
|
||||||
self.exclude_paths = exclude_paths
|
self.exclude_paths = exclude_paths
|
||||||
|
|
||||||
@ -37,7 +39,7 @@ class TraceIDMiddleware(BaseHTTPMiddleware):
|
|||||||
# self.trace_context_var.set(TracerContext(span_id=span_id))
|
# self.trace_context_var.set(TracerContext(span_id=span_id))
|
||||||
|
|
||||||
with self.tracer.start_span(
|
with self.tracer.start_span(
|
||||||
"DB-GPT-Web-Entry", span_id, metadata={"path": request.url.path}
|
self.root_operation_name, span_id, metadata={"path": request.url.path}
|
||||||
) as _:
|
):
|
||||||
response = await call_next(request)
|
response = await call_next(request)
|
||||||
return response
|
return response
|
||||||
|
Loading…
Reference in New Issue
Block a user