From 23347d52a9217d1ea62211ea8a17c6887eb86ec6 Mon Sep 17 00:00:00 2001
From: FangYin Cheng
Date: Fri, 3 Nov 2023 23:32:12 +0800
Subject: [PATCH] feat(core): More trace record for DB-GPT

---
 pilot/scene/base_chat.py          | 60 +++++++++++++++++++++++++---
 pilot/scene/chat_normal/chat.py   |  2 +
 pilot/utils/tracer/__init__.py    |  2 +
 pilot/utils/tracer/tracer_cli.py  | 65 +++++++++++++++++--------------
 pilot/utils/tracer/tracer_impl.py | 35 +++++++++++++++--
 5 files changed, 125 insertions(+), 39 deletions(-)

diff --git a/pilot/scene/base_chat.py b/pilot/scene/base_chat.py
index 34f294c31..294bd04ca 100644
--- a/pilot/scene/base_chat.py
+++ b/pilot/scene/base_chat.py
@@ -13,6 +13,7 @@ from pilot.scene.base_message import ModelMessage, ModelMessageRoleType
 from pilot.scene.message import OnceConversation
 from pilot.utils import get_or_create_event_loop
 from pilot.utils.executor_utils import ExecutorFactory, blocking_func_to_async
+from pilot.utils.tracer import root_tracer, trace
 from pydantic import Extra
 
 from pilot.memory.chat_history.chat_hisotry_factory import ChatHistory
@@ -38,6 +39,7 @@ class BaseChat(ABC):
 
         arbitrary_types_allowed = True
 
+    @trace("BaseChat.__init__")
     def __init__(self, chat_param: Dict):
         """Chat Module Initialization
         Args:
@@ -175,6 +177,14 @@
         except StopAsyncIteration:
             return True  # The iterator has finished
 
+    def _get_span_metadata(self, payload: Dict) -> Dict:
+        metadata = {k: v for k, v in payload.items()}
+        del metadata["prompt"]
+        metadata["messages"] = list(
+            map(lambda m: m if isinstance(m, dict) else m.dict(), metadata["messages"])
+        )
+        return metadata
+
     async def stream_call(self):
         # TODO Retry when server connection error
         payload = await self.__call_base()
@@ -182,6 +192,10 @@
         self.skip_echo_len = len(payload.get("prompt").replace("</s>", " ")) + 11
         logger.info(f"Request: \n{payload}")
         ai_response_text = ""
+        span = root_tracer.start_span(
+            "BaseChat.stream_call", metadata=self._get_span_metadata(payload)
+        )
+        payload["span_id"] = span.span_id
         try:
             from pilot.model.cluster import WorkerManagerFactory
 
@@ -199,6 +213,7 @@
             self.current_message.add_ai_message(msg)
             view_msg = self.knowledge_reference_call(msg)
             self.current_message.add_view_message(view_msg)
+            span.end()
         except Exception as e:
             print(traceback.format_exc())
             logger.error("model response parse failed!" + str(e))
@@ -206,12 +221,17 @@
                 f"""ERROR!{str(e)}\n {ai_response_text} """
             )
             ### store current conversation
+            span.end(metadata={"error": str(e)})
             self.memory.append(self.current_message)
 
     async def nostream_call(self):
         payload = await self.__call_base()
         logger.info(f"Request: \n{payload}")
         ai_response_text = ""
+        span = root_tracer.start_span(
+            "BaseChat.nostream_call", metadata=self._get_span_metadata(payload)
+        )
+        payload["span_id"] = span.span_id
         try:
             from pilot.model.cluster import WorkerManagerFactory
 
@@ -219,7 +239,8 @@
                 ComponentType.WORKER_MANAGER_FACTORY, WorkerManagerFactory
             ).create()
 
-            model_output = await worker_manager.generate(payload)
+            with root_tracer.start_span("BaseChat.invoke_worker_manager.generate"):
+                model_output = await worker_manager.generate(payload)
 
             ### output parse
             ai_response_text = (
@@ -234,11 +255,18 @@
                     ai_response_text
                 )
             )
-            ### run
-            # result = self.do_action(prompt_define_response)
-            result = await blocking_func_to_async(
-                self._executor, self.do_action, prompt_define_response
-            )
+            metadata = {
+                "model_output": model_output.to_dict(),
+                "ai_response_text": ai_response_text,
+                "prompt_define_response": self._parse_prompt_define_response(
+                    prompt_define_response
+                ),
+            }
+            with root_tracer.start_span("BaseChat.do_action", metadata=metadata):
+                ### run
+                result = await blocking_func_to_async(
+                    self._executor, self.do_action, prompt_define_response
+                )
 
             ### llm speaker
             speak_to_user = self.get_llm_speak(prompt_define_response)
@@ -255,12 +283,14 @@
                 view_message = view_message.replace("\n", "\\n")
             self.current_message.add_view_message(view_message)
 
+            span.end()
         except Exception as e:
             print(traceback.format_exc())
             logger.error("model response parse failed!" + str(e))
             self.current_message.add_view_message(
                 f"""ERROR!{str(e)}\n {ai_response_text} """
             )
+            span.end(metadata={"error": str(e)})
             ### store dialogue
             self.memory.append(self.current_message)
         return self.current_ai_response()
@@ -513,3 +543,21 @@
 
         """
         pass
+
+    def _parse_prompt_define_response(self, prompt_define_response: Any) -> Any:
+        if not prompt_define_response:
+            return ""
+        if isinstance(prompt_define_response, str) or isinstance(
+            prompt_define_response, dict
+        ):
+            return prompt_define_response
+        if isinstance(prompt_define_response, tuple):
+            if hasattr(prompt_define_response, "_asdict"):
+                # namedtuple
+                return prompt_define_response._asdict()
+            else:
+                return dict(
+                    zip(range(len(prompt_define_response)), prompt_define_response)
+                )
+        else:
+            return prompt_define_response
diff --git a/pilot/scene/chat_normal/chat.py b/pilot/scene/chat_normal/chat.py
index 5999d5c3c..0191ef943 100644
--- a/pilot/scene/chat_normal/chat.py
+++ b/pilot/scene/chat_normal/chat.py
@@ -5,6 +5,7 @@ from pilot.scene.base import ChatScene
 from pilot.configs.config import Config
 
 from pilot.scene.chat_normal.prompt import prompt
+from pilot.utils.tracer import root_tracer, trace
 
 CFG = Config()
 
@@ -21,6 +22,7 @@ class ChatNormal(BaseChat):
             chat_param=chat_param,
         )
 
+    @trace()
     async def generate_input_values(self) -> Dict:
         input_values = {"input": self.current_user_input}
         return input_values
diff --git a/pilot/utils/tracer/__init__.py b/pilot/utils/tracer/__init__.py
index 16509ff43..cdb536f79 100644
--- a/pilot/utils/tracer/__init__.py
+++ b/pilot/utils/tracer/__init__.py
@@ -10,6 +10,7 @@ from pilot.utils.tracer.base import (
 from pilot.utils.tracer.span_storage import MemorySpanStorage, FileSpanStorage
 from pilot.utils.tracer.tracer_impl import (
     root_tracer,
+    trace,
     initialize_tracer,
     DefaultTracer,
     TracerManager,
@@ -26,6 +27,7 @@ __all__ = [
     "MemorySpanStorage",
     "FileSpanStorage",
     "root_tracer",
+    "trace",
     "initialize_tracer",
     "DefaultTracer",
     "TracerManager",
diff --git a/pilot/utils/tracer/tracer_cli.py b/pilot/utils/tracer/tracer_cli.py
index 7df18f516..859fa4022 100644
--- a/pilot/utils/tracer/tracer_cli.py
+++ b/pilot/utils/tracer/tracer_cli.py
@@ -340,36 +340,43 @@ def chat(
             table.add_row(["echo", metadata.get("echo")])
         elif "error" in metadata:
             table.add_row(["BaseChat Error", metadata.get("error")])
-        if op == "BaseChat.nostream_call" and not sp["end_time"]:
-            if "model_output" in metadata:
-                table.add_row(
-                    [
-                        "BaseChat model_output",
-                        split_string_by_terminal_width(
-                            metadata.get("model_output").get("text"),
-                            split=split_long_text,
-                        ),
-                    ]
-                )
-            if "ai_response_text" in metadata:
-                table.add_row(
-                    [
-                        "BaseChat ai_response_text",
-                        split_string_by_terminal_width(
-                            metadata.get("ai_response_text"), split=split_long_text
-                        ),
-                    ]
-                )
-            if "prompt_define_response" in metadata:
-                table.add_row(
-                    [
-                        "BaseChat prompt_define_response",
-                        split_string_by_terminal_width(
-                            metadata.get("prompt_define_response"),
-                            split=split_long_text,
-                        ),
-                    ]
-                )
+        if op == "BaseChat.do_action" and not sp["end_time"]:
+            if "model_output" in metadata:
+                table.add_row(
+                    [
+                        "BaseChat model_output",
+                        split_string_by_terminal_width(
+                            metadata.get("model_output").get("text"),
+                            split=split_long_text,
+                        ),
+                    ]
+                )
+            if "ai_response_text" in metadata:
+                table.add_row(
+                    [
+                        "BaseChat ai_response_text",
+                        split_string_by_terminal_width(
+                            metadata.get("ai_response_text"), split=split_long_text
+                        ),
+                    ]
+                )
+            if "prompt_define_response" in metadata:
+                prompt_define_response = metadata.get("prompt_define_response") or ""
metadata.get("prompt_define_response") or "" + if isinstance(prompt_define_response, dict) or isinstance( + prompt_define_response, type([]) + ): + prompt_define_response = json.dumps( + prompt_define_response, ensure_ascii=False ) + table.add_row( + [ + "BaseChat prompt_define_response", + split_string_by_terminal_width( + prompt_define_response, + split=split_long_text, + ), + ] + ) if op == "DefaultModelWorker_call.generate_stream_func": if not sp["end_time"]: table.add_row(["llm_adapter", metadata.get("llm_adapter")]) diff --git a/pilot/utils/tracer/tracer_impl.py b/pilot/utils/tracer/tracer_impl.py index bda25ab4d..2358863bf 100644 --- a/pilot/utils/tracer/tracer_impl.py +++ b/pilot/utils/tracer/tracer_impl.py @@ -1,6 +1,9 @@ from typing import Dict, Optional from contextvars import ContextVar from functools import wraps +import asyncio +import inspect + from pilot.component import SystemApp, ComponentType from pilot.utils.tracer.base import ( @@ -154,18 +157,42 @@ class TracerManager: root_tracer: TracerManager = TracerManager() -def trace(operation_name: str, **trace_kwargs): +def trace(operation_name: Optional[str] = None, **trace_kwargs): def decorator(func): @wraps(func) - async def wrapper(*args, **kwargs): - with root_tracer.start_span(operation_name, **trace_kwargs): + def sync_wrapper(*args, **kwargs): + name = ( + operation_name if operation_name else _parse_operation_name(func, *args) + ) + with root_tracer.start_span(name, **trace_kwargs): + return func(*args, **kwargs) + + @wraps(func) + async def async_wrapper(*args, **kwargs): + name = ( + operation_name if operation_name else _parse_operation_name(func, *args) + ) + with root_tracer.start_span(name, **trace_kwargs): return await func(*args, **kwargs) - return wrapper + if asyncio.iscoroutinefunction(func): + return async_wrapper + else: + return sync_wrapper return decorator +def _parse_operation_name(func, *args): + self_name = None + if inspect.signature(func).parameters.get("self"): + self_name = args[0].__class__.__name__ + func_name = func.__name__ + if self_name: + return f"{self_name}.{func_name}" + return func_name + + def initialize_tracer( system_app: SystemApp, tracer_filename: str,