diff --git a/pilot/scene/base_chat.py b/pilot/scene/base_chat.py
index 34f294c31..294bd04ca 100644
--- a/pilot/scene/base_chat.py
+++ b/pilot/scene/base_chat.py
@@ -13,6 +13,7 @@ from pilot.scene.base_message import ModelMessage, ModelMessageRoleType
from pilot.scene.message import OnceConversation
from pilot.utils import get_or_create_event_loop
from pilot.utils.executor_utils import ExecutorFactory, blocking_func_to_async
+from pilot.utils.tracer import root_tracer, trace
from pydantic import Extra
from pilot.memory.chat_history.chat_hisotry_factory import ChatHistory
@@ -38,6 +39,7 @@ class BaseChat(ABC):
arbitrary_types_allowed = True
+ @trace("BaseChat.__init__")
def __init__(self, chat_param: Dict):
"""Chat Module Initialization
Args:
@@ -175,6 +177,14 @@ class BaseChat(ABC):
except StopAsyncIteration:
return True  # the iterator has finished
+ def _get_span_metadata(self, payload: Dict) -> Dict:
+ metadata = dict(payload)
+ del metadata["prompt"]
+ metadata["messages"] = [
+ m if isinstance(m, dict) else m.dict() for m in metadata["messages"]
+ ]
+ return metadata
+
async def stream_call(self):
# TODO: retry on server connection error
payload = await self.__call_base()
@@ -182,6 +192,10 @@ class BaseChat(ABC):
self.skip_echo_len = len(payload.get("prompt").replace("</s>", " ")) + 11
logger.info(f"Requert: \n{payload}")
ai_response_text = ""
+ span = root_tracer.start_span(
+ "BaseChat.stream_call", metadata=self._get_span_metadata(payload)
+ )
+ payload["span_id"] = span.span_id
try:
from pilot.model.cluster import WorkerManagerFactory
@@ -199,6 +213,7 @@ class BaseChat(ABC):
self.current_message.add_ai_message(msg)
view_msg = self.knowledge_reference_call(msg)
self.current_message.add_view_message(view_msg)
+ span.end()
except Exception as e:
print(traceback.format_exc())
logger.error("model response parase faild!" + str(e))
@@ -206,12 +221,17 @@ class BaseChat(ABC):
f"""ERROR!{str(e)}\n {ai_response_text} """
)
### store current conversation
+ span.end(metadata={"error": str(e)})
self.memory.append(self.current_message)
async def nostream_call(self):
payload = await self.__call_base()
logger.info(f"Request: \n{payload}")
ai_response_text = ""
+ span = root_tracer.start_span(
+ "BaseChat.nostream_call", metadata=self._get_span_metadata(payload)
+ )
+ payload["span_id"] = span.span_id
try:
from pilot.model.cluster import WorkerManagerFactory
@@ -219,7 +239,8 @@ class BaseChat(ABC):
ComponentType.WORKER_MANAGER_FACTORY, WorkerManagerFactory
).create()
- model_output = await worker_manager.generate(payload)
+ with root_tracer.start_span("BaseChat.invoke_worker_manager.generate"):
+ model_output = await worker_manager.generate(payload)
### output parse
ai_response_text = (
@@ -234,11 +255,18 @@ class BaseChat(ABC):
ai_response_text
)
)
- ### run
- # result = self.do_action(prompt_define_response)
- result = await blocking_func_to_async(
- self._executor, self.do_action, prompt_define_response
- )
+ metadata = {
+ "model_output": model_output.to_dict(),
+ "ai_response_text": ai_response_text,
+ "prompt_define_response": self._parse_prompt_define_response(
+ prompt_define_response
+ ),
+ }
+ with root_tracer.start_span("BaseChat.do_action", metadata=metadata):
+ ### run
+ result = await blocking_func_to_async(
+ self._executor, self.do_action, prompt_define_response
+ )
### llm speaker
speak_to_user = self.get_llm_speak(prompt_define_response)
@@ -255,12 +283,14 @@ class BaseChat(ABC):
view_message = view_message.replace("\n", "\\n")
self.current_message.add_view_message(view_message)
+ span.end()
except Exception as e:
print(traceback.format_exc())
logger.error("model response parase faild!" + str(e))
self.current_message.add_view_message(
f"""ERROR!{str(e)}\n {ai_response_text} """
)
+ span.end(metadata={"error": str(e)})
### store dialogue
self.memory.append(self.current_message)
return self.current_ai_response()
@@ -513,3 +543,19 @@ class BaseChat(ABC):
"""
pass
+
+ def _parse_prompt_define_response(self, prompt_define_response: Any) -> Any:
+ if not prompt_define_response:
+ return ""
+ if isinstance(prompt_define_response, (str, dict)):
+ return prompt_define_response
+ if isinstance(prompt_define_response, tuple):
+ if hasattr(prompt_define_response, "_asdict"):
+ # namedtuple
+ return prompt_define_response._asdict()
+ else:
+ return dict(
+ zip(range(len(prompt_define_response)), prompt_define_response)
+ )
+ else:
+ return prompt_define_response
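For reference, a standalone sketch (standard library only) of how the new _parse_prompt_define_response helper normalizes do_action() inputs before they are attached to span metadata: namedtuples become dicts via _asdict(), plain tuples become index-keyed dicts, and str/dict values pass through. SqlAction is an illustrative namedtuple, not a type from this diff.

from collections import namedtuple

def parse_prompt_define_response(resp):
    # Mirrors the helper added to BaseChat above.
    if not resp:
        return ""
    if isinstance(resp, (str, dict)):
        return resp
    if isinstance(resp, tuple):
        if hasattr(resp, "_asdict"):  # namedtuple
            return resp._asdict()
        return dict(zip(range(len(resp)), resp))
    return resp

SqlAction = namedtuple("SqlAction", ["sql", "thoughts"])
print(parse_prompt_define_response(SqlAction("SELECT 1", "probe")))
# -> {'sql': 'SELECT 1', 'thoughts': 'probe'}
print(parse_prompt_define_response(("a", "b")))  # -> {0: 'a', 1: 'b'}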
diff --git a/pilot/scene/chat_normal/chat.py b/pilot/scene/chat_normal/chat.py
index 5999d5c3c..0191ef943 100644
--- a/pilot/scene/chat_normal/chat.py
+++ b/pilot/scene/chat_normal/chat.py
@@ -5,6 +5,7 @@ from pilot.scene.base import ChatScene
from pilot.configs.config import Config
from pilot.scene.chat_normal.prompt import prompt
+from pilot.utils.tracer import root_tracer, trace
CFG = Config()
@@ -21,6 +22,7 @@ class ChatNormal(BaseChat):
chat_param=chat_param,
)
+ @trace()
async def generate_input_values(self) -> Dict:
input_values = {"input": self.current_user_input}
return input_values
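The bare @trace() above records a span whose name is derived from the decorated method. Below is a standalone sketch of that fallback, mirroring the _parse_operation_name helper added in tracer_impl.py later in this diff; derive_name is a hypothetical stand-in, not part of this diff.

import inspect

def derive_name(func, *args):
    # "<ClassName>.<method>" when the function takes self, else the bare name.
    if inspect.signature(func).parameters.get("self") and args:
        return f"{args[0].__class__.__name__}.{func.__name__}"
    return func.__name__

class ChatNormal:
    async def generate_input_values(self):
        return {"input": "hello"}

print(derive_name(ChatNormal.generate_input_values, ChatNormal()))
# -> ChatNormal.generate_input_values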
diff --git a/pilot/utils/tracer/__init__.py b/pilot/utils/tracer/__init__.py
index 16509ff43..cdb536f79 100644
--- a/pilot/utils/tracer/__init__.py
+++ b/pilot/utils/tracer/__init__.py
@@ -10,6 +10,7 @@ from pilot.utils.tracer.base import (
from pilot.utils.tracer.span_storage import MemorySpanStorage, FileSpanStorage
from pilot.utils.tracer.tracer_impl import (
root_tracer,
+ trace,
initialize_tracer,
DefaultTracer,
TracerManager,
@@ -26,6 +27,7 @@ __all__ = [
"MemorySpanStorage",
"FileSpanStorage",
"root_tracer",
+ "trace",
"initialize_tracer",
"DefaultTracer",
"TracerManager",
diff --git a/pilot/utils/tracer/tracer_cli.py b/pilot/utils/tracer/tracer_cli.py
index 7df18f516..859fa4022 100644
--- a/pilot/utils/tracer/tracer_cli.py
+++ b/pilot/utils/tracer/tracer_cli.py
@@ -340,36 +340,41 @@ def chat(
table.add_row(["echo", metadata.get("echo")])
elif "error" in metadata:
table.add_row(["BaseChat Error", metadata.get("error")])
- if op == "BaseChat.nostream_call" and not sp["end_time"]:
- if "model_output" in metadata:
- table.add_row(
- [
- "BaseChat model_output",
- split_string_by_terminal_width(
- metadata.get("model_output").get("text"),
- split=split_long_text,
- ),
- ]
- )
- if "ai_response_text" in metadata:
- table.add_row(
- [
- "BaseChat ai_response_text",
- split_string_by_terminal_width(
- metadata.get("ai_response_text"), split=split_long_text
- ),
- ]
- )
- if "prompt_define_response" in metadata:
- table.add_row(
- [
- "BaseChat prompt_define_response",
- split_string_by_terminal_width(
- metadata.get("prompt_define_response"),
- split=split_long_text,
- ),
- ]
+ if op == "BaseChat.do_action" and not sp["end_time"]:
+ if "model_output" in metadata:
+ table.add_row(
+ [
+ "BaseChat model_output",
+ split_string_by_terminal_width(
+ metadata.get("model_output").get("text"),
+ split=split_long_text,
+ ),
+ ]
+ )
+ if "ai_response_text" in metadata:
+ table.add_row(
+ [
+ "BaseChat ai_response_text",
+ split_string_by_terminal_width(
+ metadata.get("ai_response_text"), split=split_long_text
+ ),
+ ]
+ )
+ if "prompt_define_response" in metadata:
+ prompt_define_response = metadata.get("prompt_define_response") or ""
+ if isinstance(prompt_define_response, (dict, list)):
+ prompt_define_response = json.dumps(
+ prompt_define_response, ensure_ascii=False
)
+ table.add_row(
+ [
+ "BaseChat prompt_define_response",
+ split_string_by_terminal_width(
+ prompt_define_response,
+ split=split_long_text,
+ ),
+ ]
+ )
if op == "DefaultModelWorker_call.generate_stream_func":
if not sp["end_time"]:
table.add_row(["llm_adapter", metadata.get("llm_adapter")])
diff --git a/pilot/utils/tracer/tracer_impl.py b/pilot/utils/tracer/tracer_impl.py
index bda25ab4d..2358863bf 100644
--- a/pilot/utils/tracer/tracer_impl.py
+++ b/pilot/utils/tracer/tracer_impl.py
@@ -1,6 +1,9 @@
from typing import Dict, Optional
from contextvars import ContextVar
from functools import wraps
+import asyncio
+import inspect
+
from pilot.component import SystemApp, ComponentType
from pilot.utils.tracer.base import (
@@ -154,18 +157,42 @@ class TracerManager:
root_tracer: TracerManager = TracerManager()
-def trace(operation_name: str, **trace_kwargs):
+def trace(operation_name: Optional[str] = None, **trace_kwargs):
def decorator(func):
@wraps(func)
- async def wrapper(*args, **kwargs):
- with root_tracer.start_span(operation_name, **trace_kwargs):
+ def sync_wrapper(*args, **kwargs):
+ name = (
+ operation_name if operation_name else _parse_operation_name(func, *args)
+ )
+ with root_tracer.start_span(name, **trace_kwargs):
+ return func(*args, **kwargs)
+
+ @wraps(func)
+ async def async_wrapper(*args, **kwargs):
+ name = (
+ operation_name if operation_name else _parse_operation_name(func, *args)
+ )
+ with root_tracer.start_span(name, **trace_kwargs):
return await func(*args, **kwargs)
- return wrapper
+ if asyncio.iscoroutinefunction(func):
+ return async_wrapper
+ else:
+ return sync_wrapper
return decorator
+def _parse_operation_name(func, *args):
+ self_name = None
+ if args and inspect.signature(func).parameters.get("self"):
+ self_name = args[0].__class__.__name__
+ func_name = func.__name__
+ if self_name:
+ return f"{self_name}.{func_name}"
+ return func_name
+
+
def initialize_tracer(
system_app: SystemApp,
tracer_filename: str,
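To close, a standalone sketch (standard library only) of the dual-wrapper dispatch the new trace() decorator uses: asyncio.iscoroutinefunction() selects an async wrapper for coroutines and a plain one for sync functions, so a single decorator can instrument both. traced and the print calls are hypothetical stand-ins for root_tracer.start_span.

import asyncio
from functools import wraps

def traced(name):
    def decorator(func):
        @wraps(func)
        def sync_wrapper(*args, **kwargs):
            print(f"start span: {name}")  # stand-in for root_tracer.start_span
            return func(*args, **kwargs)

        @wraps(func)
        async def async_wrapper(*args, **kwargs):
            print(f"start span: {name}")
            return await func(*args, **kwargs)

        return async_wrapper if asyncio.iscoroutinefunction(func) else sync_wrapper

    return decorator

@traced("demo.add")
def add(a, b):
    return a + b

@traced("demo.fetch")
async def fetch():
    return "ok"

print(add(1, 2))             # start span: demo.add, then 3
print(asyncio.run(fetch()))  # start span: demo.fetch, then ok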