mirror of
https://github.com/csunny/DB-GPT.git
synced 2025-07-30 15:21:02 +00:00
feat(core): More trace records for DB-GPT (#775)
- More trace record for DB-GPT - Support pass span id to threadpool
This commit is contained in:
commit
1d2b054372
@ -34,6 +34,7 @@ CREATE TABLE `knowledge_document` (
|
||||
`content` LONGTEXT NOT NULL COMMENT 'knowledge embedding sync result',
|
||||
`result` TEXT NULL COMMENT 'knowledge content',
|
||||
`vector_ids` LONGTEXT NULL COMMENT 'vector_ids',
|
||||
`summary` LONGTEXT NULL COMMENT 'knowledge summary',
|
||||
`gmt_created` TIMESTAMP DEFAULT CURRENT_TIMESTAMP COMMENT 'created time',
|
||||
`gmt_modified` TIMESTAMP DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP COMMENT 'update time',
|
||||
PRIMARY KEY (`id`),
|
||||
|
@ -13,6 +13,7 @@ from pilot.scene.base_message import ModelMessage, ModelMessageRoleType
|
||||
from pilot.scene.message import OnceConversation
|
||||
from pilot.utils import get_or_create_event_loop
|
||||
from pilot.utils.executor_utils import ExecutorFactory, blocking_func_to_async
|
||||
from pilot.utils.tracer import root_tracer, trace
|
||||
from pydantic import Extra
|
||||
from pilot.memory.chat_history.chat_hisotry_factory import ChatHistory
|
||||
|
||||
@ -38,6 +39,7 @@ class BaseChat(ABC):
|
||||
|
||||
arbitrary_types_allowed = True
|
||||
|
||||
@trace("BaseChat.__init__")
|
||||
def __init__(self, chat_param: Dict):
|
||||
"""Chat Module Initialization
|
||||
Args:
|
||||
@ -143,7 +145,14 @@ class BaseChat(ABC):
|
||||
)
|
||||
self.current_message.tokens = 0
|
||||
if self.prompt_template.template:
|
||||
current_prompt = self.prompt_template.format(**input_values)
|
||||
metadata = {
|
||||
"template_scene": self.prompt_template.template_scene,
|
||||
"input_values": input_values,
|
||||
}
|
||||
with root_tracer.start_span(
|
||||
"BaseChat.__call_base.prompt_template.format", metadata=metadata
|
||||
):
|
||||
current_prompt = self.prompt_template.format(**input_values)
|
||||
self.current_message.add_system_message(current_prompt)
|
||||
|
||||
llm_messages = self.generate_llm_messages()
|
||||
@ -175,6 +184,14 @@ class BaseChat(ABC):
|
||||
except StopAsyncIteration:
|
||||
return True # 迭代器已经执行结束
|
||||
|
||||
def _get_span_metadata(self, payload: Dict) -> Dict:
|
||||
metadata = {k: v for k, v in payload.items()}
|
||||
del metadata["prompt"]
|
||||
metadata["messages"] = list(
|
||||
map(lambda m: m if isinstance(m, dict) else m.dict(), metadata["messages"])
|
||||
)
|
||||
return metadata
|
||||
|
||||
async def stream_call(self):
|
||||
# TODO Retry when server connection error
|
||||
payload = await self.__call_base()
|
||||
@ -182,6 +199,10 @@ class BaseChat(ABC):
|
||||
self.skip_echo_len = len(payload.get("prompt").replace("</s>", " ")) + 11
|
||||
logger.info(f"Requert: \n{payload}")
|
||||
ai_response_text = ""
|
||||
span = root_tracer.start_span(
|
||||
"BaseChat.stream_call", metadata=self._get_span_metadata(payload)
|
||||
)
|
||||
payload["span_id"] = span.span_id
|
||||
try:
|
||||
from pilot.model.cluster import WorkerManagerFactory
|
||||
|
||||
@ -199,6 +220,7 @@ class BaseChat(ABC):
|
||||
self.current_message.add_ai_message(msg)
|
||||
view_msg = self.knowledge_reference_call(msg)
|
||||
self.current_message.add_view_message(view_msg)
|
||||
span.end()
|
||||
except Exception as e:
|
||||
print(traceback.format_exc())
|
||||
logger.error("model response parase faild!" + str(e))
|
||||
@ -206,12 +228,17 @@ class BaseChat(ABC):
|
||||
f"""<span style=\"color:red\">ERROR!</span>{str(e)}\n {ai_response_text} """
|
||||
)
|
||||
### store current conversation
|
||||
span.end(metadata={"error": str(e)})
|
||||
self.memory.append(self.current_message)
|
||||
|
||||
async def nostream_call(self):
|
||||
payload = await self.__call_base()
|
||||
logger.info(f"Request: \n{payload}")
|
||||
ai_response_text = ""
|
||||
span = root_tracer.start_span(
|
||||
"BaseChat.nostream_call", metadata=self._get_span_metadata(payload)
|
||||
)
|
||||
payload["span_id"] = span.span_id
|
||||
try:
|
||||
from pilot.model.cluster import WorkerManagerFactory
|
||||
|
||||
@ -219,7 +246,8 @@ class BaseChat(ABC):
|
||||
ComponentType.WORKER_MANAGER_FACTORY, WorkerManagerFactory
|
||||
).create()
|
||||
|
||||
model_output = await worker_manager.generate(payload)
|
||||
with root_tracer.start_span("BaseChat.invoke_worker_manager.generate"):
|
||||
model_output = await worker_manager.generate(payload)
|
||||
|
||||
### output parse
|
||||
ai_response_text = (
|
||||
@ -234,11 +262,18 @@ class BaseChat(ABC):
|
||||
ai_response_text
|
||||
)
|
||||
)
|
||||
### run
|
||||
# result = self.do_action(prompt_define_response)
|
||||
result = await blocking_func_to_async(
|
||||
self._executor, self.do_action, prompt_define_response
|
||||
)
|
||||
metadata = {
|
||||
"model_output": model_output.to_dict(),
|
||||
"ai_response_text": ai_response_text,
|
||||
"prompt_define_response": self._parse_prompt_define_response(
|
||||
prompt_define_response
|
||||
),
|
||||
}
|
||||
with root_tracer.start_span("BaseChat.do_action", metadata=metadata):
|
||||
### run
|
||||
result = await blocking_func_to_async(
|
||||
self._executor, self.do_action, prompt_define_response
|
||||
)
|
||||
|
||||
### llm speaker
|
||||
speak_to_user = self.get_llm_speak(prompt_define_response)
|
||||
@ -255,12 +290,14 @@ class BaseChat(ABC):
|
||||
|
||||
view_message = view_message.replace("\n", "\\n")
|
||||
self.current_message.add_view_message(view_message)
|
||||
span.end()
|
||||
except Exception as e:
|
||||
print(traceback.format_exc())
|
||||
logger.error("model response parase faild!" + str(e))
|
||||
self.current_message.add_view_message(
|
||||
f"""<span style=\"color:red\">ERROR!</span>{str(e)}\n {ai_response_text} """
|
||||
)
|
||||
span.end(metadata={"error": str(e)})
|
||||
### store dialogue
|
||||
self.memory.append(self.current_message)
|
||||
return self.current_ai_response()
|
||||
@ -513,3 +550,21 @@ class BaseChat(ABC):
|
||||
|
||||
"""
|
||||
pass
|
||||
|
||||
def _parse_prompt_define_response(self, prompt_define_response: Any) -> Any:
|
||||
if not prompt_define_response:
|
||||
return ""
|
||||
if isinstance(prompt_define_response, str) or isinstance(
|
||||
prompt_define_response, dict
|
||||
):
|
||||
return prompt_define_response
|
||||
if isinstance(prompt_define_response, tuple):
|
||||
if hasattr(prompt_define_response, "_asdict"):
|
||||
# namedtuple
|
||||
return prompt_define_response._asdict()
|
||||
else:
|
||||
return dict(
|
||||
zip(range(len(prompt_define_response)), prompt_define_response)
|
||||
)
|
||||
else:
|
||||
return prompt_define_response
|
||||
|
@ -11,6 +11,7 @@ from pilot.common.string_utils import extract_content
|
||||
from .prompt import prompt
|
||||
from pilot.component import ComponentType
|
||||
from pilot.base_modules.agent.controller import ModuleAgent
|
||||
from pilot.utils.tracer import root_tracer, trace
|
||||
|
||||
CFG = Config()
|
||||
|
||||
@ -51,6 +52,7 @@ class ChatAgent(BaseChat):
|
||||
|
||||
self.api_call = ApiCall(plugin_generator=self.plugins_prompt_generator)
|
||||
|
||||
@trace()
|
||||
async def generate_input_values(self) -> Dict[str, str]:
|
||||
input_values = {
|
||||
"user_goal": self.current_user_input,
|
||||
@ -63,7 +65,10 @@ class ChatAgent(BaseChat):
|
||||
|
||||
def stream_plugin_call(self, text):
|
||||
text = text.replace("\n", " ")
|
||||
return self.api_call.run(text)
|
||||
with root_tracer.start_span(
|
||||
"ChatAgent.stream_plugin_call.api_call", metadata={"text": text}
|
||||
):
|
||||
return self.api_call.run(text)
|
||||
|
||||
def __list_to_prompt_str(self, list: List) -> str:
|
||||
return "\n".join(f"{i + 1 + 1}. {item}" for i, item in enumerate(list))
|
||||
|
@ -13,6 +13,7 @@ from pilot.scene.chat_dashboard.data_preparation.report_schma import (
|
||||
from pilot.scene.chat_dashboard.prompt import prompt
|
||||
from pilot.scene.chat_dashboard.data_loader import DashboardDataLoader
|
||||
from pilot.utils.executor_utils import blocking_func_to_async
|
||||
from pilot.utils.tracer import root_tracer, trace
|
||||
|
||||
CFG = Config()
|
||||
|
||||
@ -53,6 +54,7 @@ class ChatDashboard(BaseChat):
|
||||
data = f.read()
|
||||
return json.loads(data)
|
||||
|
||||
@trace()
|
||||
async def generate_input_values(self) -> Dict:
|
||||
try:
|
||||
from pilot.summary.db_summary_client import DBSummaryClient
|
||||
|
@ -14,6 +14,7 @@ from pilot.scene.chat_data.chat_excel.excel_learning.chat import ExcelLearning
|
||||
from pilot.common.path_utils import has_path
|
||||
from pilot.configs.model_config import LLM_MODEL_CONFIG, KNOWLEDGE_UPLOAD_ROOT_PATH
|
||||
from pilot.base_modules.agent.common.schema import Status
|
||||
from pilot.utils.tracer import root_tracer, trace
|
||||
|
||||
CFG = Config()
|
||||
|
||||
@ -62,6 +63,7 @@ class ChatExcel(BaseChat):
|
||||
# ]
|
||||
return "\n".join(f"{i+1}. {item}" for i, item in enumerate(command_strings))
|
||||
|
||||
@trace()
|
||||
async def generate_input_values(self) -> Dict:
|
||||
input_values = {
|
||||
"user_input": self.current_user_input,
|
||||
@ -88,4 +90,9 @@ class ChatExcel(BaseChat):
|
||||
|
||||
def stream_plugin_call(self, text):
|
||||
text = text.replace("\n", " ")
|
||||
return self.api_call.run_display_sql(text, self.excel_reader.get_df_by_sql_ex)
|
||||
with root_tracer.start_span(
|
||||
"ChatExcel.stream_plugin_call.run_display_sql", metadata={"text": text}
|
||||
):
|
||||
return self.api_call.run_display_sql(
|
||||
text, self.excel_reader.get_df_by_sql_ex
|
||||
)
|
||||
|
@ -13,6 +13,7 @@ from pilot.scene.chat_data.chat_excel.excel_learning.prompt import prompt
|
||||
from pilot.scene.chat_data.chat_excel.excel_reader import ExcelReader
|
||||
from pilot.json_utils.utilities import DateTimeEncoder
|
||||
from pilot.utils.executor_utils import blocking_func_to_async
|
||||
from pilot.utils.tracer import root_tracer, trace
|
||||
|
||||
CFG = Config()
|
||||
|
||||
@ -44,6 +45,7 @@ class ExcelLearning(BaseChat):
|
||||
if parent_mode:
|
||||
self.current_message.chat_mode = parent_mode.value()
|
||||
|
||||
@trace()
|
||||
async def generate_input_values(self) -> Dict:
|
||||
# colunms, datas = self.excel_reader.get_sample_data()
|
||||
colunms, datas = await blocking_func_to_async(
|
||||
|
@ -6,6 +6,7 @@ from pilot.common.sql_database import Database
|
||||
from pilot.configs.config import Config
|
||||
from pilot.scene.chat_db.auto_execute.prompt import prompt
|
||||
from pilot.utils.executor_utils import blocking_func_to_async
|
||||
from pilot.utils.tracer import root_tracer, trace
|
||||
|
||||
CFG = Config()
|
||||
|
||||
@ -35,10 +36,13 @@ class ChatWithDbAutoExecute(BaseChat):
|
||||
raise ValueError(
|
||||
f"{ChatScene.ChatWithDbExecute.value} mode should chose db!"
|
||||
)
|
||||
|
||||
self.database = CFG.LOCAL_DB_MANAGE.get_connect(self.db_name)
|
||||
with root_tracer.start_span(
|
||||
"ChatWithDbAutoExecute.get_connect", metadata={"db_name": self.db_name}
|
||||
):
|
||||
self.database = CFG.LOCAL_DB_MANAGE.get_connect(self.db_name)
|
||||
self.top_k: int = 200
|
||||
|
||||
@trace()
|
||||
async def generate_input_values(self) -> Dict:
|
||||
"""
|
||||
generate input values
|
||||
@ -55,13 +59,14 @@ class ChatWithDbAutoExecute(BaseChat):
|
||||
# query=self.current_user_input,
|
||||
# topk=CFG.KNOWLEDGE_SEARCH_TOP_SIZE,
|
||||
# )
|
||||
table_infos = await blocking_func_to_async(
|
||||
self._executor,
|
||||
client.get_db_summary,
|
||||
self.db_name,
|
||||
self.current_user_input,
|
||||
CFG.KNOWLEDGE_SEARCH_TOP_SIZE,
|
||||
)
|
||||
with root_tracer.start_span("ChatWithDbAutoExecute.get_db_summary"):
|
||||
table_infos = await blocking_func_to_async(
|
||||
self._executor,
|
||||
client.get_db_summary,
|
||||
self.db_name,
|
||||
self.current_user_input,
|
||||
CFG.KNOWLEDGE_SEARCH_TOP_SIZE,
|
||||
)
|
||||
except Exception as e:
|
||||
print("db summary find error!" + str(e))
|
||||
if not table_infos:
|
||||
@ -80,4 +85,8 @@ class ChatWithDbAutoExecute(BaseChat):
|
||||
|
||||
def do_action(self, prompt_response):
|
||||
print(f"do_action:{prompt_response}")
|
||||
return self.database.run(prompt_response.sql)
|
||||
with root_tracer.start_span(
|
||||
"ChatWithDbAutoExecute.do_action.run_sql",
|
||||
metadata=prompt_response.to_dict(),
|
||||
):
|
||||
return self.database.run(prompt_response.sql)
|
||||
|
@ -12,6 +12,9 @@ class SqlAction(NamedTuple):
|
||||
sql: str
|
||||
thoughts: Dict
|
||||
|
||||
def to_dict(self) -> Dict[str, Dict]:
|
||||
return {"sql": self.sql, "thoughts": self.thoughts}
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
@ -6,6 +6,7 @@ from pilot.common.sql_database import Database
|
||||
from pilot.configs.config import Config
|
||||
from pilot.scene.chat_db.professional_qa.prompt import prompt
|
||||
from pilot.utils.executor_utils import blocking_func_to_async
|
||||
from pilot.utils.tracer import root_tracer, trace
|
||||
|
||||
CFG = Config()
|
||||
|
||||
@ -39,6 +40,7 @@ class ChatWithDbQA(BaseChat):
|
||||
else len(self.tables)
|
||||
)
|
||||
|
||||
@trace()
|
||||
async def generate_input_values(self) -> Dict:
|
||||
table_info = ""
|
||||
dialect = "mysql"
|
||||
|
@ -6,6 +6,7 @@ from pilot.configs.config import Config
|
||||
from pilot.base_modules.agent.commands.command import execute_command
|
||||
from pilot.base_modules.agent import PluginPromptGenerator
|
||||
from .prompt import prompt
|
||||
from pilot.utils.tracer import root_tracer, trace
|
||||
|
||||
CFG = Config()
|
||||
|
||||
@ -50,6 +51,7 @@ class ChatWithPlugin(BaseChat):
|
||||
self.plugins_prompt_generator
|
||||
)
|
||||
|
||||
@trace()
|
||||
async def generate_input_values(self) -> Dict:
|
||||
input_values = {
|
||||
"input": self.current_user_input,
|
||||
|
@ -4,6 +4,7 @@ from pilot.scene.base import ChatScene
|
||||
from pilot.configs.config import Config
|
||||
|
||||
from pilot.scene.chat_knowledge.inner_db_summary.prompt import prompt
|
||||
from pilot.utils.tracer import root_tracer, trace
|
||||
|
||||
CFG = Config()
|
||||
|
||||
@ -31,6 +32,7 @@ class InnerChatDBSummary(BaseChat):
|
||||
self.db_input = db_select
|
||||
self.db_summary = db_summary
|
||||
|
||||
@trace()
|
||||
async def generate_input_values(self) -> Dict:
|
||||
input_values = {
|
||||
"db_input": self.db_input,
|
||||
|
@ -15,6 +15,7 @@ from pilot.configs.model_config import (
|
||||
from pilot.scene.chat_knowledge.v1.prompt import prompt
|
||||
from pilot.server.knowledge.service import KnowledgeService
|
||||
from pilot.utils.executor_utils import blocking_func_to_async
|
||||
from pilot.utils.tracer import root_tracer, trace
|
||||
|
||||
CFG = Config()
|
||||
|
||||
@ -92,6 +93,7 @@ class ChatKnowledge(BaseChat):
|
||||
"""return reference"""
|
||||
return text + f"\n\n{self.parse_source_view(self.sources)}"
|
||||
|
||||
@trace()
|
||||
async def generate_input_values(self) -> Dict:
|
||||
if self.space_context:
|
||||
self.prompt_template.template_define = self.space_context["prompt"]["scene"]
|
||||
|
@ -5,6 +5,7 @@ from pilot.scene.base import ChatScene
|
||||
from pilot.configs.config import Config
|
||||
|
||||
from pilot.scene.chat_normal.prompt import prompt
|
||||
from pilot.utils.tracer import root_tracer, trace
|
||||
|
||||
CFG = Config()
|
||||
|
||||
@ -21,6 +22,7 @@ class ChatNormal(BaseChat):
|
||||
chat_param=chat_param,
|
||||
)
|
||||
|
||||
@trace()
|
||||
async def generate_input_values(self) -> Dict:
|
||||
input_values = {"input": self.current_user_input}
|
||||
return input_values
|
||||
|
@ -1,5 +1,6 @@
|
||||
from typing import Callable, Awaitable, Any
|
||||
import asyncio
|
||||
import contextvars
|
||||
from abc import ABC, abstractmethod
|
||||
from concurrent.futures import Executor, ThreadPoolExecutor
|
||||
from functools import partial
|
||||
@ -55,6 +56,12 @@ async def blocking_func_to_async(
|
||||
"""
|
||||
if asyncio.iscoroutinefunction(func):
|
||||
raise ValueError(f"The function {func} is not blocking function")
|
||||
|
||||
# This function will be called within the new thread, capturing the current context
|
||||
ctx = contextvars.copy_context()
|
||||
|
||||
def run_with_context():
|
||||
return ctx.run(partial(func, *args, **kwargs))
|
||||
|
||||
loop = asyncio.get_event_loop()
|
||||
sync_function_noargs = partial(func, *args, **kwargs)
|
||||
return await loop.run_in_executor(executor, sync_function_noargs)
|
||||
return await loop.run_in_executor(executor, run_with_context)
|
||||
|
@ -10,6 +10,7 @@ from pilot.utils.tracer.base import (
|
||||
from pilot.utils.tracer.span_storage import MemorySpanStorage, FileSpanStorage
|
||||
from pilot.utils.tracer.tracer_impl import (
|
||||
root_tracer,
|
||||
trace,
|
||||
initialize_tracer,
|
||||
DefaultTracer,
|
||||
TracerManager,
|
||||
@ -26,6 +27,7 @@ __all__ = [
|
||||
"MemorySpanStorage",
|
||||
"FileSpanStorage",
|
||||
"root_tracer",
|
||||
"trace",
|
||||
"initialize_tracer",
|
||||
"DefaultTracer",
|
||||
"TracerManager",
|
||||
|
@ -303,8 +303,6 @@ def chat(
|
||||
print(table.get_formatted_string(out_format=output, **out_kwargs))
|
||||
if sys_table:
|
||||
print(sys_table.get_formatted_string(out_format=output, **out_kwargs))
|
||||
if hide_conv:
|
||||
return
|
||||
|
||||
if not found_trace_id:
|
||||
print(f"Can't found conversation with trace_id: {trace_id}")
|
||||
@ -315,9 +313,12 @@ def chat(
|
||||
trace_spans = [s for s in reversed(trace_spans)]
|
||||
hierarchy = _build_trace_hierarchy(trace_spans)
|
||||
if tree:
|
||||
print("\nInvoke Trace Tree:\n")
|
||||
print(f"\nInvoke Trace Tree(trace_id: {trace_id}):\n")
|
||||
_print_trace_hierarchy(hierarchy)
|
||||
|
||||
if hide_conv:
|
||||
return
|
||||
|
||||
trace_spans = _get_ordered_trace_from(hierarchy)
|
||||
table = PrettyTable(["Key", "Value Value"], title="Chat Trace Details")
|
||||
split_long_text = output == "text"
|
||||
@ -340,36 +341,43 @@ def chat(
|
||||
table.add_row(["echo", metadata.get("echo")])
|
||||
elif "error" in metadata:
|
||||
table.add_row(["BaseChat Error", metadata.get("error")])
|
||||
if op == "BaseChat.nostream_call" and not sp["end_time"]:
|
||||
if "model_output" in metadata:
|
||||
table.add_row(
|
||||
[
|
||||
"BaseChat model_output",
|
||||
split_string_by_terminal_width(
|
||||
metadata.get("model_output").get("text"),
|
||||
split=split_long_text,
|
||||
),
|
||||
]
|
||||
)
|
||||
if "ai_response_text" in metadata:
|
||||
table.add_row(
|
||||
[
|
||||
"BaseChat ai_response_text",
|
||||
split_string_by_terminal_width(
|
||||
metadata.get("ai_response_text"), split=split_long_text
|
||||
),
|
||||
]
|
||||
)
|
||||
if "prompt_define_response" in metadata:
|
||||
table.add_row(
|
||||
[
|
||||
"BaseChat prompt_define_response",
|
||||
split_string_by_terminal_width(
|
||||
metadata.get("prompt_define_response"),
|
||||
split=split_long_text,
|
||||
),
|
||||
]
|
||||
if op == "BaseChat.do_action" and not sp["end_time"]:
|
||||
if "model_output" in metadata:
|
||||
table.add_row(
|
||||
[
|
||||
"BaseChat model_output",
|
||||
split_string_by_terminal_width(
|
||||
metadata.get("model_output").get("text"),
|
||||
split=split_long_text,
|
||||
),
|
||||
]
|
||||
)
|
||||
if "ai_response_text" in metadata:
|
||||
table.add_row(
|
||||
[
|
||||
"BaseChat ai_response_text",
|
||||
split_string_by_terminal_width(
|
||||
metadata.get("ai_response_text"), split=split_long_text
|
||||
),
|
||||
]
|
||||
)
|
||||
if "prompt_define_response" in metadata:
|
||||
prompt_define_response = metadata.get("prompt_define_response") or ""
|
||||
if isinstance(prompt_define_response, dict) or isinstance(
|
||||
prompt_define_response, type([])
|
||||
):
|
||||
prompt_define_response = json.dumps(
|
||||
prompt_define_response, ensure_ascii=False
|
||||
)
|
||||
table.add_row(
|
||||
[
|
||||
"BaseChat prompt_define_response",
|
||||
split_string_by_terminal_width(
|
||||
prompt_define_response,
|
||||
split=split_long_text,
|
||||
),
|
||||
]
|
||||
)
|
||||
if op == "DefaultModelWorker_call.generate_stream_func":
|
||||
if not sp["end_time"]:
|
||||
table.add_row(["llm_adapter", metadata.get("llm_adapter")])
|
||||
|
@ -1,6 +1,9 @@
|
||||
from typing import Dict, Optional
|
||||
from contextvars import ContextVar
|
||||
from functools import wraps
|
||||
import asyncio
|
||||
import inspect
|
||||
|
||||
|
||||
from pilot.component import SystemApp, ComponentType
|
||||
from pilot.utils.tracer.base import (
|
||||
@ -154,18 +157,42 @@ class TracerManager:
|
||||
root_tracer: TracerManager = TracerManager()
|
||||
|
||||
|
||||
def trace(operation_name: str, **trace_kwargs):
|
||||
def trace(operation_name: Optional[str] = None, **trace_kwargs):
|
||||
def decorator(func):
|
||||
@wraps(func)
|
||||
async def wrapper(*args, **kwargs):
|
||||
with root_tracer.start_span(operation_name, **trace_kwargs):
|
||||
def sync_wrapper(*args, **kwargs):
|
||||
name = (
|
||||
operation_name if operation_name else _parse_operation_name(func, *args)
|
||||
)
|
||||
with root_tracer.start_span(name, **trace_kwargs):
|
||||
return func(*args, **kwargs)
|
||||
|
||||
@wraps(func)
|
||||
async def async_wrapper(*args, **kwargs):
|
||||
name = (
|
||||
operation_name if operation_name else _parse_operation_name(func, *args)
|
||||
)
|
||||
with root_tracer.start_span(name, **trace_kwargs):
|
||||
return await func(*args, **kwargs)
|
||||
|
||||
return wrapper
|
||||
if asyncio.iscoroutinefunction(func):
|
||||
return async_wrapper
|
||||
else:
|
||||
return sync_wrapper
|
||||
|
||||
return decorator
|
||||
|
||||
|
||||
def _parse_operation_name(func, *args):
|
||||
self_name = None
|
||||
if inspect.signature(func).parameters.get("self"):
|
||||
self_name = args[0].__class__.__name__
|
||||
func_name = func.__name__
|
||||
if self_name:
|
||||
return f"{self_name}.{func_name}"
|
||||
return func_name
|
||||
|
||||
|
||||
def initialize_tracer(
|
||||
system_app: SystemApp,
|
||||
tracer_filename: str,
|
||||
|
Loading…
Reference in New Issue
Block a user