feat(core): More trace records for DB-GPT

This commit is contained in:
FangYin Cheng
2023-11-03 23:32:12 +08:00
parent d685f834ba
commit 23347d52a9
5 changed files with 125 additions and 39 deletions

View File

@@ -13,6 +13,7 @@ from pilot.scene.base_message import ModelMessage, ModelMessageRoleType
from pilot.scene.message import OnceConversation from pilot.scene.message import OnceConversation
from pilot.utils import get_or_create_event_loop from pilot.utils import get_or_create_event_loop
from pilot.utils.executor_utils import ExecutorFactory, blocking_func_to_async from pilot.utils.executor_utils import ExecutorFactory, blocking_func_to_async
from pilot.utils.tracer import root_tracer, trace
from pydantic import Extra from pydantic import Extra
from pilot.memory.chat_history.chat_hisotry_factory import ChatHistory from pilot.memory.chat_history.chat_hisotry_factory import ChatHistory
@@ -38,6 +39,7 @@ class BaseChat(ABC):
arbitrary_types_allowed = True arbitrary_types_allowed = True
@trace("BaseChat.__init__")
def __init__(self, chat_param: Dict): def __init__(self, chat_param: Dict):
"""Chat Module Initialization """Chat Module Initialization
Args: Args:
@@ -175,6 +177,14 @@ class BaseChat(ABC):
except StopAsyncIteration: except StopAsyncIteration:
return True # 迭代器已经执行结束 return True # 迭代器已经执行结束
def _get_span_metadata(self, payload: Dict) -> Dict:
metadata = {k: v for k, v in payload.items()}
del metadata["prompt"]
metadata["messages"] = list(
map(lambda m: m if isinstance(m, dict) else m.dict(), metadata["messages"])
)
return metadata
async def stream_call(self): async def stream_call(self):
# TODO Retry when server connection error # TODO Retry when server connection error
payload = await self.__call_base() payload = await self.__call_base()
@@ -182,6 +192,10 @@ class BaseChat(ABC):
self.skip_echo_len = len(payload.get("prompt").replace("</s>", " ")) + 11 self.skip_echo_len = len(payload.get("prompt").replace("</s>", " ")) + 11
logger.info(f"Requert: \n{payload}") logger.info(f"Requert: \n{payload}")
ai_response_text = "" ai_response_text = ""
span = root_tracer.start_span(
"BaseChat.stream_call", metadata=self._get_span_metadata(payload)
)
payload["span_id"] = span.span_id
try: try:
from pilot.model.cluster import WorkerManagerFactory from pilot.model.cluster import WorkerManagerFactory
@@ -199,6 +213,7 @@ class BaseChat(ABC):
self.current_message.add_ai_message(msg) self.current_message.add_ai_message(msg)
view_msg = self.knowledge_reference_call(msg) view_msg = self.knowledge_reference_call(msg)
self.current_message.add_view_message(view_msg) self.current_message.add_view_message(view_msg)
span.end()
except Exception as e: except Exception as e:
print(traceback.format_exc()) print(traceback.format_exc())
logger.error("model response parase faild" + str(e)) logger.error("model response parase faild" + str(e))
@@ -206,12 +221,17 @@ class BaseChat(ABC):
f"""<span style=\"color:red\">ERROR!</span>{str(e)}\n {ai_response_text} """ f"""<span style=\"color:red\">ERROR!</span>{str(e)}\n {ai_response_text} """
) )
### store current conversation ### store current conversation
span.end(metadata={"error": str(e)})
self.memory.append(self.current_message) self.memory.append(self.current_message)
async def nostream_call(self): async def nostream_call(self):
payload = await self.__call_base() payload = await self.__call_base()
logger.info(f"Request: \n{payload}") logger.info(f"Request: \n{payload}")
ai_response_text = "" ai_response_text = ""
span = root_tracer.start_span(
"BaseChat.nostream_call", metadata=self._get_span_metadata(payload)
)
payload["span_id"] = span.span_id
try: try:
from pilot.model.cluster import WorkerManagerFactory from pilot.model.cluster import WorkerManagerFactory
@@ -219,6 +239,7 @@ class BaseChat(ABC):
ComponentType.WORKER_MANAGER_FACTORY, WorkerManagerFactory ComponentType.WORKER_MANAGER_FACTORY, WorkerManagerFactory
).create() ).create()
with root_tracer.start_span("BaseChat.invoke_worker_manager.generate"):
model_output = await worker_manager.generate(payload) model_output = await worker_manager.generate(payload)
### output parse ### output parse
@@ -234,8 +255,15 @@ class BaseChat(ABC):
ai_response_text ai_response_text
) )
) )
metadata = {
"model_output": model_output.to_dict(),
"ai_response_text": ai_response_text,
"prompt_define_response": self._parse_prompt_define_response(
prompt_define_response
),
}
with root_tracer.start_span("BaseChat.do_action", metadata=metadata):
### run ### run
# result = self.do_action(prompt_define_response)
result = await blocking_func_to_async( result = await blocking_func_to_async(
self._executor, self.do_action, prompt_define_response self._executor, self.do_action, prompt_define_response
) )
@@ -255,12 +283,14 @@ class BaseChat(ABC):
view_message = view_message.replace("\n", "\\n") view_message = view_message.replace("\n", "\\n")
self.current_message.add_view_message(view_message) self.current_message.add_view_message(view_message)
span.end()
except Exception as e: except Exception as e:
print(traceback.format_exc()) print(traceback.format_exc())
logger.error("model response parase faild" + str(e)) logger.error("model response parase faild" + str(e))
self.current_message.add_view_message( self.current_message.add_view_message(
f"""<span style=\"color:red\">ERROR!</span>{str(e)}\n {ai_response_text} """ f"""<span style=\"color:red\">ERROR!</span>{str(e)}\n {ai_response_text} """
) )
span.end(metadata={"error": str(e)})
### store dialogue ### store dialogue
self.memory.append(self.current_message) self.memory.append(self.current_message)
return self.current_ai_response() return self.current_ai_response()
@@ -513,3 +543,21 @@ class BaseChat(ABC):
""" """
pass pass
def _parse_prompt_define_response(self, prompt_define_response: Any) -> Any:
if not prompt_define_response:
return ""
if isinstance(prompt_define_response, str) or isinstance(
prompt_define_response, dict
):
return prompt_define_response
if isinstance(prompt_define_response, tuple):
if hasattr(prompt_define_response, "_asdict"):
# namedtuple
return prompt_define_response._asdict()
else:
return dict(
zip(range(len(prompt_define_response)), prompt_define_response)
)
else:
return prompt_define_response

View File

@@ -5,6 +5,7 @@ from pilot.scene.base import ChatScene
from pilot.configs.config import Config from pilot.configs.config import Config
from pilot.scene.chat_normal.prompt import prompt from pilot.scene.chat_normal.prompt import prompt
from pilot.utils.tracer import root_tracer, trace
CFG = Config() CFG = Config()
@@ -21,6 +22,7 @@ class ChatNormal(BaseChat):
chat_param=chat_param, chat_param=chat_param,
) )
@trace()
async def generate_input_values(self) -> Dict: async def generate_input_values(self) -> Dict:
input_values = {"input": self.current_user_input} input_values = {"input": self.current_user_input}
return input_values return input_values

View File

@@ -10,6 +10,7 @@ from pilot.utils.tracer.base import (
from pilot.utils.tracer.span_storage import MemorySpanStorage, FileSpanStorage from pilot.utils.tracer.span_storage import MemorySpanStorage, FileSpanStorage
from pilot.utils.tracer.tracer_impl import ( from pilot.utils.tracer.tracer_impl import (
root_tracer, root_tracer,
trace,
initialize_tracer, initialize_tracer,
DefaultTracer, DefaultTracer,
TracerManager, TracerManager,
@@ -26,6 +27,7 @@ __all__ = [
"MemorySpanStorage", "MemorySpanStorage",
"FileSpanStorage", "FileSpanStorage",
"root_tracer", "root_tracer",
"trace",
"initialize_tracer", "initialize_tracer",
"DefaultTracer", "DefaultTracer",
"TracerManager", "TracerManager",

View File

@@ -340,7 +340,7 @@ def chat(
table.add_row(["echo", metadata.get("echo")]) table.add_row(["echo", metadata.get("echo")])
elif "error" in metadata: elif "error" in metadata:
table.add_row(["BaseChat Error", metadata.get("error")]) table.add_row(["BaseChat Error", metadata.get("error")])
if op == "BaseChat.nostream_call" and not sp["end_time"]: if op == "BaseChat.do_action" and not sp["end_time"]:
if "model_output" in metadata: if "model_output" in metadata:
table.add_row( table.add_row(
[ [
@@ -361,11 +361,18 @@ def chat(
] ]
) )
if "prompt_define_response" in metadata: if "prompt_define_response" in metadata:
prompt_define_response = metadata.get("prompt_define_response") or ""
if isinstance(prompt_define_response, dict) or isinstance(
prompt_define_response, type([])
):
prompt_define_response = json.dumps(
prompt_define_response, ensure_ascii=False
)
table.add_row( table.add_row(
[ [
"BaseChat prompt_define_response", "BaseChat prompt_define_response",
split_string_by_terminal_width( split_string_by_terminal_width(
metadata.get("prompt_define_response"), prompt_define_response,
split=split_long_text, split=split_long_text,
), ),
] ]

View File

@@ -1,6 +1,9 @@
from typing import Dict, Optional from typing import Dict, Optional
from contextvars import ContextVar from contextvars import ContextVar
from functools import wraps from functools import wraps
import asyncio
import inspect
from pilot.component import SystemApp, ComponentType from pilot.component import SystemApp, ComponentType
from pilot.utils.tracer.base import ( from pilot.utils.tracer.base import (
@@ -154,18 +157,42 @@ class TracerManager:
root_tracer: TracerManager = TracerManager() root_tracer: TracerManager = TracerManager()
def trace(operation_name: str, **trace_kwargs): def trace(operation_name: Optional[str] = None, **trace_kwargs):
def decorator(func): def decorator(func):
@wraps(func) @wraps(func)
async def wrapper(*args, **kwargs): def sync_wrapper(*args, **kwargs):
with root_tracer.start_span(operation_name, **trace_kwargs): name = (
operation_name if operation_name else _parse_operation_name(func, *args)
)
with root_tracer.start_span(name, **trace_kwargs):
return func(*args, **kwargs)
@wraps(func)
async def async_wrapper(*args, **kwargs):
name = (
operation_name if operation_name else _parse_operation_name(func, *args)
)
with root_tracer.start_span(name, **trace_kwargs):
return await func(*args, **kwargs) return await func(*args, **kwargs)
return wrapper if asyncio.iscoroutinefunction(func):
return async_wrapper
else:
return sync_wrapper
return decorator return decorator
def _parse_operation_name(func, *args):
self_name = None
if inspect.signature(func).parameters.get("self"):
self_name = args[0].__class__.__name__
func_name = func.__name__
if self_name:
return f"{self_name}.{func_name}"
return func_name
def initialize_tracer( def initialize_tracer(
system_app: SystemApp, system_app: SystemApp,
tracer_filename: str, tracer_filename: str,