feat(core): More trace records for DB-GPT (#775)

- More trace record for DB-GPT
- Support pass span id to threadpool
This commit is contained in:
Aries-ckt 2023-11-04 18:15:53 +08:00 committed by GitHub
commit 1d2b054372
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
17 changed files with 195 additions and 57 deletions

View File

@ -34,6 +34,7 @@ CREATE TABLE `knowledge_document` (
`content` LONGTEXT NOT NULL COMMENT 'knowledge embedding sync result', `content` LONGTEXT NOT NULL COMMENT 'knowledge embedding sync result',
`result` TEXT NULL COMMENT 'knowledge content', `result` TEXT NULL COMMENT 'knowledge content',
`vector_ids` LONGTEXT NULL COMMENT 'vector_ids', `vector_ids` LONGTEXT NULL COMMENT 'vector_ids',
`summary` LONGTEXT NULL COMMENT 'knowledge summary',
`gmt_created` TIMESTAMP DEFAULT CURRENT_TIMESTAMP COMMENT 'created time', `gmt_created` TIMESTAMP DEFAULT CURRENT_TIMESTAMP COMMENT 'created time',
`gmt_modified` TIMESTAMP DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP COMMENT 'update time', `gmt_modified` TIMESTAMP DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP COMMENT 'update time',
PRIMARY KEY (`id`), PRIMARY KEY (`id`),

View File

@ -13,6 +13,7 @@ from pilot.scene.base_message import ModelMessage, ModelMessageRoleType
from pilot.scene.message import OnceConversation from pilot.scene.message import OnceConversation
from pilot.utils import get_or_create_event_loop from pilot.utils import get_or_create_event_loop
from pilot.utils.executor_utils import ExecutorFactory, blocking_func_to_async from pilot.utils.executor_utils import ExecutorFactory, blocking_func_to_async
from pilot.utils.tracer import root_tracer, trace
from pydantic import Extra from pydantic import Extra
from pilot.memory.chat_history.chat_hisotry_factory import ChatHistory from pilot.memory.chat_history.chat_hisotry_factory import ChatHistory
@ -38,6 +39,7 @@ class BaseChat(ABC):
arbitrary_types_allowed = True arbitrary_types_allowed = True
@trace("BaseChat.__init__")
def __init__(self, chat_param: Dict): def __init__(self, chat_param: Dict):
"""Chat Module Initialization """Chat Module Initialization
Args: Args:
@ -143,7 +145,14 @@ class BaseChat(ABC):
) )
self.current_message.tokens = 0 self.current_message.tokens = 0
if self.prompt_template.template: if self.prompt_template.template:
current_prompt = self.prompt_template.format(**input_values) metadata = {
"template_scene": self.prompt_template.template_scene,
"input_values": input_values,
}
with root_tracer.start_span(
"BaseChat.__call_base.prompt_template.format", metadata=metadata
):
current_prompt = self.prompt_template.format(**input_values)
self.current_message.add_system_message(current_prompt) self.current_message.add_system_message(current_prompt)
llm_messages = self.generate_llm_messages() llm_messages = self.generate_llm_messages()
@ -175,6 +184,14 @@ class BaseChat(ABC):
except StopAsyncIteration: except StopAsyncIteration:
return True # 迭代器已经执行结束 return True # 迭代器已经执行结束
def _get_span_metadata(self, payload: Dict) -> Dict:
metadata = {k: v for k, v in payload.items()}
del metadata["prompt"]
metadata["messages"] = list(
map(lambda m: m if isinstance(m, dict) else m.dict(), metadata["messages"])
)
return metadata
async def stream_call(self): async def stream_call(self):
# TODO Retry when server connection error # TODO Retry when server connection error
payload = await self.__call_base() payload = await self.__call_base()
@ -182,6 +199,10 @@ class BaseChat(ABC):
self.skip_echo_len = len(payload.get("prompt").replace("</s>", " ")) + 11 self.skip_echo_len = len(payload.get("prompt").replace("</s>", " ")) + 11
logger.info(f"Requert: \n{payload}") logger.info(f"Requert: \n{payload}")
ai_response_text = "" ai_response_text = ""
span = root_tracer.start_span(
"BaseChat.stream_call", metadata=self._get_span_metadata(payload)
)
payload["span_id"] = span.span_id
try: try:
from pilot.model.cluster import WorkerManagerFactory from pilot.model.cluster import WorkerManagerFactory
@ -199,6 +220,7 @@ class BaseChat(ABC):
self.current_message.add_ai_message(msg) self.current_message.add_ai_message(msg)
view_msg = self.knowledge_reference_call(msg) view_msg = self.knowledge_reference_call(msg)
self.current_message.add_view_message(view_msg) self.current_message.add_view_message(view_msg)
span.end()
except Exception as e: except Exception as e:
print(traceback.format_exc()) print(traceback.format_exc())
logger.error("model response parase faild" + str(e)) logger.error("model response parase faild" + str(e))
@ -206,12 +228,17 @@ class BaseChat(ABC):
f"""<span style=\"color:red\">ERROR!</span>{str(e)}\n {ai_response_text} """ f"""<span style=\"color:red\">ERROR!</span>{str(e)}\n {ai_response_text} """
) )
### store current conversation ### store current conversation
span.end(metadata={"error": str(e)})
self.memory.append(self.current_message) self.memory.append(self.current_message)
async def nostream_call(self): async def nostream_call(self):
payload = await self.__call_base() payload = await self.__call_base()
logger.info(f"Request: \n{payload}") logger.info(f"Request: \n{payload}")
ai_response_text = "" ai_response_text = ""
span = root_tracer.start_span(
"BaseChat.nostream_call", metadata=self._get_span_metadata(payload)
)
payload["span_id"] = span.span_id
try: try:
from pilot.model.cluster import WorkerManagerFactory from pilot.model.cluster import WorkerManagerFactory
@ -219,7 +246,8 @@ class BaseChat(ABC):
ComponentType.WORKER_MANAGER_FACTORY, WorkerManagerFactory ComponentType.WORKER_MANAGER_FACTORY, WorkerManagerFactory
).create() ).create()
model_output = await worker_manager.generate(payload) with root_tracer.start_span("BaseChat.invoke_worker_manager.generate"):
model_output = await worker_manager.generate(payload)
### output parse ### output parse
ai_response_text = ( ai_response_text = (
@ -234,11 +262,18 @@ class BaseChat(ABC):
ai_response_text ai_response_text
) )
) )
### run metadata = {
# result = self.do_action(prompt_define_response) "model_output": model_output.to_dict(),
result = await blocking_func_to_async( "ai_response_text": ai_response_text,
self._executor, self.do_action, prompt_define_response "prompt_define_response": self._parse_prompt_define_response(
) prompt_define_response
),
}
with root_tracer.start_span("BaseChat.do_action", metadata=metadata):
### run
result = await blocking_func_to_async(
self._executor, self.do_action, prompt_define_response
)
### llm speaker ### llm speaker
speak_to_user = self.get_llm_speak(prompt_define_response) speak_to_user = self.get_llm_speak(prompt_define_response)
@ -255,12 +290,14 @@ class BaseChat(ABC):
view_message = view_message.replace("\n", "\\n") view_message = view_message.replace("\n", "\\n")
self.current_message.add_view_message(view_message) self.current_message.add_view_message(view_message)
span.end()
except Exception as e: except Exception as e:
print(traceback.format_exc()) print(traceback.format_exc())
logger.error("model response parase faild" + str(e)) logger.error("model response parase faild" + str(e))
self.current_message.add_view_message( self.current_message.add_view_message(
f"""<span style=\"color:red\">ERROR!</span>{str(e)}\n {ai_response_text} """ f"""<span style=\"color:red\">ERROR!</span>{str(e)}\n {ai_response_text} """
) )
span.end(metadata={"error": str(e)})
### store dialogue ### store dialogue
self.memory.append(self.current_message) self.memory.append(self.current_message)
return self.current_ai_response() return self.current_ai_response()
@ -513,3 +550,21 @@ class BaseChat(ABC):
""" """
pass pass
def _parse_prompt_define_response(self, prompt_define_response: Any) -> Any:
if not prompt_define_response:
return ""
if isinstance(prompt_define_response, str) or isinstance(
prompt_define_response, dict
):
return prompt_define_response
if isinstance(prompt_define_response, tuple):
if hasattr(prompt_define_response, "_asdict"):
# namedtuple
return prompt_define_response._asdict()
else:
return dict(
zip(range(len(prompt_define_response)), prompt_define_response)
)
else:
return prompt_define_response

View File

@ -11,6 +11,7 @@ from pilot.common.string_utils import extract_content
from .prompt import prompt from .prompt import prompt
from pilot.component import ComponentType from pilot.component import ComponentType
from pilot.base_modules.agent.controller import ModuleAgent from pilot.base_modules.agent.controller import ModuleAgent
from pilot.utils.tracer import root_tracer, trace
CFG = Config() CFG = Config()
@ -51,6 +52,7 @@ class ChatAgent(BaseChat):
self.api_call = ApiCall(plugin_generator=self.plugins_prompt_generator) self.api_call = ApiCall(plugin_generator=self.plugins_prompt_generator)
@trace()
async def generate_input_values(self) -> Dict[str, str]: async def generate_input_values(self) -> Dict[str, str]:
input_values = { input_values = {
"user_goal": self.current_user_input, "user_goal": self.current_user_input,
@ -63,7 +65,10 @@ class ChatAgent(BaseChat):
def stream_plugin_call(self, text): def stream_plugin_call(self, text):
text = text.replace("\n", " ") text = text.replace("\n", " ")
return self.api_call.run(text) with root_tracer.start_span(
"ChatAgent.stream_plugin_call.api_call", metadata={"text": text}
):
return self.api_call.run(text)
def __list_to_prompt_str(self, list: List) -> str: def __list_to_prompt_str(self, list: List) -> str:
return "\n".join(f"{i + 1 + 1}. {item}" for i, item in enumerate(list)) return "\n".join(f"{i + 1 + 1}. {item}" for i, item in enumerate(list))

View File

@ -13,6 +13,7 @@ from pilot.scene.chat_dashboard.data_preparation.report_schma import (
from pilot.scene.chat_dashboard.prompt import prompt from pilot.scene.chat_dashboard.prompt import prompt
from pilot.scene.chat_dashboard.data_loader import DashboardDataLoader from pilot.scene.chat_dashboard.data_loader import DashboardDataLoader
from pilot.utils.executor_utils import blocking_func_to_async from pilot.utils.executor_utils import blocking_func_to_async
from pilot.utils.tracer import root_tracer, trace
CFG = Config() CFG = Config()
@ -53,6 +54,7 @@ class ChatDashboard(BaseChat):
data = f.read() data = f.read()
return json.loads(data) return json.loads(data)
@trace()
async def generate_input_values(self) -> Dict: async def generate_input_values(self) -> Dict:
try: try:
from pilot.summary.db_summary_client import DBSummaryClient from pilot.summary.db_summary_client import DBSummaryClient

View File

@ -14,6 +14,7 @@ from pilot.scene.chat_data.chat_excel.excel_learning.chat import ExcelLearning
from pilot.common.path_utils import has_path from pilot.common.path_utils import has_path
from pilot.configs.model_config import LLM_MODEL_CONFIG, KNOWLEDGE_UPLOAD_ROOT_PATH from pilot.configs.model_config import LLM_MODEL_CONFIG, KNOWLEDGE_UPLOAD_ROOT_PATH
from pilot.base_modules.agent.common.schema import Status from pilot.base_modules.agent.common.schema import Status
from pilot.utils.tracer import root_tracer, trace
CFG = Config() CFG = Config()
@ -62,6 +63,7 @@ class ChatExcel(BaseChat):
# ] # ]
return "\n".join(f"{i+1}. {item}" for i, item in enumerate(command_strings)) return "\n".join(f"{i+1}. {item}" for i, item in enumerate(command_strings))
@trace()
async def generate_input_values(self) -> Dict: async def generate_input_values(self) -> Dict:
input_values = { input_values = {
"user_input": self.current_user_input, "user_input": self.current_user_input,
@ -88,4 +90,9 @@ class ChatExcel(BaseChat):
def stream_plugin_call(self, text): def stream_plugin_call(self, text):
text = text.replace("\n", " ") text = text.replace("\n", " ")
return self.api_call.run_display_sql(text, self.excel_reader.get_df_by_sql_ex) with root_tracer.start_span(
"ChatExcel.stream_plugin_call.run_display_sql", metadata={"text": text}
):
return self.api_call.run_display_sql(
text, self.excel_reader.get_df_by_sql_ex
)

View File

@ -13,6 +13,7 @@ from pilot.scene.chat_data.chat_excel.excel_learning.prompt import prompt
from pilot.scene.chat_data.chat_excel.excel_reader import ExcelReader from pilot.scene.chat_data.chat_excel.excel_reader import ExcelReader
from pilot.json_utils.utilities import DateTimeEncoder from pilot.json_utils.utilities import DateTimeEncoder
from pilot.utils.executor_utils import blocking_func_to_async from pilot.utils.executor_utils import blocking_func_to_async
from pilot.utils.tracer import root_tracer, trace
CFG = Config() CFG = Config()
@ -44,6 +45,7 @@ class ExcelLearning(BaseChat):
if parent_mode: if parent_mode:
self.current_message.chat_mode = parent_mode.value() self.current_message.chat_mode = parent_mode.value()
@trace()
async def generate_input_values(self) -> Dict: async def generate_input_values(self) -> Dict:
# colunms, datas = self.excel_reader.get_sample_data() # colunms, datas = self.excel_reader.get_sample_data()
colunms, datas = await blocking_func_to_async( colunms, datas = await blocking_func_to_async(

View File

@ -6,6 +6,7 @@ from pilot.common.sql_database import Database
from pilot.configs.config import Config from pilot.configs.config import Config
from pilot.scene.chat_db.auto_execute.prompt import prompt from pilot.scene.chat_db.auto_execute.prompt import prompt
from pilot.utils.executor_utils import blocking_func_to_async from pilot.utils.executor_utils import blocking_func_to_async
from pilot.utils.tracer import root_tracer, trace
CFG = Config() CFG = Config()
@ -35,10 +36,13 @@ class ChatWithDbAutoExecute(BaseChat):
raise ValueError( raise ValueError(
f"{ChatScene.ChatWithDbExecute.value} mode should chose db!" f"{ChatScene.ChatWithDbExecute.value} mode should chose db!"
) )
with root_tracer.start_span(
self.database = CFG.LOCAL_DB_MANAGE.get_connect(self.db_name) "ChatWithDbAutoExecute.get_connect", metadata={"db_name": self.db_name}
):
self.database = CFG.LOCAL_DB_MANAGE.get_connect(self.db_name)
self.top_k: int = 200 self.top_k: int = 200
@trace()
async def generate_input_values(self) -> Dict: async def generate_input_values(self) -> Dict:
""" """
generate input values generate input values
@ -55,13 +59,14 @@ class ChatWithDbAutoExecute(BaseChat):
# query=self.current_user_input, # query=self.current_user_input,
# topk=CFG.KNOWLEDGE_SEARCH_TOP_SIZE, # topk=CFG.KNOWLEDGE_SEARCH_TOP_SIZE,
# ) # )
table_infos = await blocking_func_to_async( with root_tracer.start_span("ChatWithDbAutoExecute.get_db_summary"):
self._executor, table_infos = await blocking_func_to_async(
client.get_db_summary, self._executor,
self.db_name, client.get_db_summary,
self.current_user_input, self.db_name,
CFG.KNOWLEDGE_SEARCH_TOP_SIZE, self.current_user_input,
) CFG.KNOWLEDGE_SEARCH_TOP_SIZE,
)
except Exception as e: except Exception as e:
print("db summary find error!" + str(e)) print("db summary find error!" + str(e))
if not table_infos: if not table_infos:
@ -80,4 +85,8 @@ class ChatWithDbAutoExecute(BaseChat):
def do_action(self, prompt_response): def do_action(self, prompt_response):
print(f"do_action:{prompt_response}") print(f"do_action:{prompt_response}")
return self.database.run(prompt_response.sql) with root_tracer.start_span(
"ChatWithDbAutoExecute.do_action.run_sql",
metadata=prompt_response.to_dict(),
):
return self.database.run(prompt_response.sql)

View File

@ -12,6 +12,9 @@ class SqlAction(NamedTuple):
sql: str sql: str
thoughts: Dict thoughts: Dict
def to_dict(self) -> Dict[str, Dict]:
return {"sql": self.sql, "thoughts": self.thoughts}
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)

View File

@ -6,6 +6,7 @@ from pilot.common.sql_database import Database
from pilot.configs.config import Config from pilot.configs.config import Config
from pilot.scene.chat_db.professional_qa.prompt import prompt from pilot.scene.chat_db.professional_qa.prompt import prompt
from pilot.utils.executor_utils import blocking_func_to_async from pilot.utils.executor_utils import blocking_func_to_async
from pilot.utils.tracer import root_tracer, trace
CFG = Config() CFG = Config()
@ -39,6 +40,7 @@ class ChatWithDbQA(BaseChat):
else len(self.tables) else len(self.tables)
) )
@trace()
async def generate_input_values(self) -> Dict: async def generate_input_values(self) -> Dict:
table_info = "" table_info = ""
dialect = "mysql" dialect = "mysql"

View File

@ -6,6 +6,7 @@ from pilot.configs.config import Config
from pilot.base_modules.agent.commands.command import execute_command from pilot.base_modules.agent.commands.command import execute_command
from pilot.base_modules.agent import PluginPromptGenerator from pilot.base_modules.agent import PluginPromptGenerator
from .prompt import prompt from .prompt import prompt
from pilot.utils.tracer import root_tracer, trace
CFG = Config() CFG = Config()
@ -50,6 +51,7 @@ class ChatWithPlugin(BaseChat):
self.plugins_prompt_generator self.plugins_prompt_generator
) )
@trace()
async def generate_input_values(self) -> Dict: async def generate_input_values(self) -> Dict:
input_values = { input_values = {
"input": self.current_user_input, "input": self.current_user_input,

View File

@ -4,6 +4,7 @@ from pilot.scene.base import ChatScene
from pilot.configs.config import Config from pilot.configs.config import Config
from pilot.scene.chat_knowledge.inner_db_summary.prompt import prompt from pilot.scene.chat_knowledge.inner_db_summary.prompt import prompt
from pilot.utils.tracer import root_tracer, trace
CFG = Config() CFG = Config()
@ -31,6 +32,7 @@ class InnerChatDBSummary(BaseChat):
self.db_input = db_select self.db_input = db_select
self.db_summary = db_summary self.db_summary = db_summary
@trace()
async def generate_input_values(self) -> Dict: async def generate_input_values(self) -> Dict:
input_values = { input_values = {
"db_input": self.db_input, "db_input": self.db_input,

View File

@ -15,6 +15,7 @@ from pilot.configs.model_config import (
from pilot.scene.chat_knowledge.v1.prompt import prompt from pilot.scene.chat_knowledge.v1.prompt import prompt
from pilot.server.knowledge.service import KnowledgeService from pilot.server.knowledge.service import KnowledgeService
from pilot.utils.executor_utils import blocking_func_to_async from pilot.utils.executor_utils import blocking_func_to_async
from pilot.utils.tracer import root_tracer, trace
CFG = Config() CFG = Config()
@ -92,6 +93,7 @@ class ChatKnowledge(BaseChat):
"""return reference""" """return reference"""
return text + f"\n\n{self.parse_source_view(self.sources)}" return text + f"\n\n{self.parse_source_view(self.sources)}"
@trace()
async def generate_input_values(self) -> Dict: async def generate_input_values(self) -> Dict:
if self.space_context: if self.space_context:
self.prompt_template.template_define = self.space_context["prompt"]["scene"] self.prompt_template.template_define = self.space_context["prompt"]["scene"]

View File

@ -5,6 +5,7 @@ from pilot.scene.base import ChatScene
from pilot.configs.config import Config from pilot.configs.config import Config
from pilot.scene.chat_normal.prompt import prompt from pilot.scene.chat_normal.prompt import prompt
from pilot.utils.tracer import root_tracer, trace
CFG = Config() CFG = Config()
@ -21,6 +22,7 @@ class ChatNormal(BaseChat):
chat_param=chat_param, chat_param=chat_param,
) )
@trace()
async def generate_input_values(self) -> Dict: async def generate_input_values(self) -> Dict:
input_values = {"input": self.current_user_input} input_values = {"input": self.current_user_input}
return input_values return input_values

View File

@ -1,5 +1,6 @@
from typing import Callable, Awaitable, Any from typing import Callable, Awaitable, Any
import asyncio import asyncio
import contextvars
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from concurrent.futures import Executor, ThreadPoolExecutor from concurrent.futures import Executor, ThreadPoolExecutor
from functools import partial from functools import partial
@ -55,6 +56,12 @@ async def blocking_func_to_async(
""" """
if asyncio.iscoroutinefunction(func): if asyncio.iscoroutinefunction(func):
raise ValueError(f"The function {func} is not blocking function") raise ValueError(f"The function {func} is not blocking function")
# This function will be called within the new thread, capturing the current context
ctx = contextvars.copy_context()
def run_with_context():
return ctx.run(partial(func, *args, **kwargs))
loop = asyncio.get_event_loop() loop = asyncio.get_event_loop()
sync_function_noargs = partial(func, *args, **kwargs) return await loop.run_in_executor(executor, run_with_context)
return await loop.run_in_executor(executor, sync_function_noargs)

View File

@ -10,6 +10,7 @@ from pilot.utils.tracer.base import (
from pilot.utils.tracer.span_storage import MemorySpanStorage, FileSpanStorage from pilot.utils.tracer.span_storage import MemorySpanStorage, FileSpanStorage
from pilot.utils.tracer.tracer_impl import ( from pilot.utils.tracer.tracer_impl import (
root_tracer, root_tracer,
trace,
initialize_tracer, initialize_tracer,
DefaultTracer, DefaultTracer,
TracerManager, TracerManager,
@ -26,6 +27,7 @@ __all__ = [
"MemorySpanStorage", "MemorySpanStorage",
"FileSpanStorage", "FileSpanStorage",
"root_tracer", "root_tracer",
"trace",
"initialize_tracer", "initialize_tracer",
"DefaultTracer", "DefaultTracer",
"TracerManager", "TracerManager",

View File

@ -303,8 +303,6 @@ def chat(
print(table.get_formatted_string(out_format=output, **out_kwargs)) print(table.get_formatted_string(out_format=output, **out_kwargs))
if sys_table: if sys_table:
print(sys_table.get_formatted_string(out_format=output, **out_kwargs)) print(sys_table.get_formatted_string(out_format=output, **out_kwargs))
if hide_conv:
return
if not found_trace_id: if not found_trace_id:
print(f"Can't found conversation with trace_id: {trace_id}") print(f"Can't found conversation with trace_id: {trace_id}")
@ -315,9 +313,12 @@ def chat(
trace_spans = [s for s in reversed(trace_spans)] trace_spans = [s for s in reversed(trace_spans)]
hierarchy = _build_trace_hierarchy(trace_spans) hierarchy = _build_trace_hierarchy(trace_spans)
if tree: if tree:
print("\nInvoke Trace Tree:\n") print(f"\nInvoke Trace Tree(trace_id: {trace_id}):\n")
_print_trace_hierarchy(hierarchy) _print_trace_hierarchy(hierarchy)
if hide_conv:
return
trace_spans = _get_ordered_trace_from(hierarchy) trace_spans = _get_ordered_trace_from(hierarchy)
table = PrettyTable(["Key", "Value Value"], title="Chat Trace Details") table = PrettyTable(["Key", "Value Value"], title="Chat Trace Details")
split_long_text = output == "text" split_long_text = output == "text"
@ -340,36 +341,43 @@ def chat(
table.add_row(["echo", metadata.get("echo")]) table.add_row(["echo", metadata.get("echo")])
elif "error" in metadata: elif "error" in metadata:
table.add_row(["BaseChat Error", metadata.get("error")]) table.add_row(["BaseChat Error", metadata.get("error")])
if op == "BaseChat.nostream_call" and not sp["end_time"]: if op == "BaseChat.do_action" and not sp["end_time"]:
if "model_output" in metadata: if "model_output" in metadata:
table.add_row( table.add_row(
[ [
"BaseChat model_output", "BaseChat model_output",
split_string_by_terminal_width( split_string_by_terminal_width(
metadata.get("model_output").get("text"), metadata.get("model_output").get("text"),
split=split_long_text, split=split_long_text,
), ),
] ]
) )
if "ai_response_text" in metadata: if "ai_response_text" in metadata:
table.add_row( table.add_row(
[ [
"BaseChat ai_response_text", "BaseChat ai_response_text",
split_string_by_terminal_width( split_string_by_terminal_width(
metadata.get("ai_response_text"), split=split_long_text metadata.get("ai_response_text"), split=split_long_text
), ),
] ]
) )
if "prompt_define_response" in metadata: if "prompt_define_response" in metadata:
table.add_row( prompt_define_response = metadata.get("prompt_define_response") or ""
[ if isinstance(prompt_define_response, dict) or isinstance(
"BaseChat prompt_define_response", prompt_define_response, type([])
split_string_by_terminal_width( ):
metadata.get("prompt_define_response"), prompt_define_response = json.dumps(
split=split_long_text, prompt_define_response, ensure_ascii=False
),
]
) )
table.add_row(
[
"BaseChat prompt_define_response",
split_string_by_terminal_width(
prompt_define_response,
split=split_long_text,
),
]
)
if op == "DefaultModelWorker_call.generate_stream_func": if op == "DefaultModelWorker_call.generate_stream_func":
if not sp["end_time"]: if not sp["end_time"]:
table.add_row(["llm_adapter", metadata.get("llm_adapter")]) table.add_row(["llm_adapter", metadata.get("llm_adapter")])

View File

@ -1,6 +1,9 @@
from typing import Dict, Optional from typing import Dict, Optional
from contextvars import ContextVar from contextvars import ContextVar
from functools import wraps from functools import wraps
import asyncio
import inspect
from pilot.component import SystemApp, ComponentType from pilot.component import SystemApp, ComponentType
from pilot.utils.tracer.base import ( from pilot.utils.tracer.base import (
@ -154,18 +157,42 @@ class TracerManager:
root_tracer: TracerManager = TracerManager() root_tracer: TracerManager = TracerManager()
def trace(operation_name: str, **trace_kwargs): def trace(operation_name: Optional[str] = None, **trace_kwargs):
def decorator(func): def decorator(func):
@wraps(func) @wraps(func)
async def wrapper(*args, **kwargs): def sync_wrapper(*args, **kwargs):
with root_tracer.start_span(operation_name, **trace_kwargs): name = (
operation_name if operation_name else _parse_operation_name(func, *args)
)
with root_tracer.start_span(name, **trace_kwargs):
return func(*args, **kwargs)
@wraps(func)
async def async_wrapper(*args, **kwargs):
name = (
operation_name if operation_name else _parse_operation_name(func, *args)
)
with root_tracer.start_span(name, **trace_kwargs):
return await func(*args, **kwargs) return await func(*args, **kwargs)
return wrapper if asyncio.iscoroutinefunction(func):
return async_wrapper
else:
return sync_wrapper
return decorator return decorator
def _parse_operation_name(func, *args):
self_name = None
if inspect.signature(func).parameters.get("self"):
self_name = args[0].__class__.__name__
func_name = func.__name__
if self_name:
return f"{self_name}.{func_name}"
return func_name
def initialize_tracer( def initialize_tracer(
system_app: SystemApp, system_app: SystemApp,
tracer_filename: str, tracer_filename: str,