mirror of
https://github.com/csunny/DB-GPT.git
synced 2025-09-02 17:45:31 +00:00
feat(core): More trace record for DB-GPT
This commit is contained in:
@@ -13,6 +13,7 @@ from pilot.scene.base_message import ModelMessage, ModelMessageRoleType
|
|||||||
from pilot.scene.message import OnceConversation
|
from pilot.scene.message import OnceConversation
|
||||||
from pilot.utils import get_or_create_event_loop
|
from pilot.utils import get_or_create_event_loop
|
||||||
from pilot.utils.executor_utils import ExecutorFactory, blocking_func_to_async
|
from pilot.utils.executor_utils import ExecutorFactory, blocking_func_to_async
|
||||||
|
from pilot.utils.tracer import root_tracer, trace
|
||||||
from pydantic import Extra
|
from pydantic import Extra
|
||||||
from pilot.memory.chat_history.chat_hisotry_factory import ChatHistory
|
from pilot.memory.chat_history.chat_hisotry_factory import ChatHistory
|
||||||
|
|
||||||
@@ -38,6 +39,7 @@ class BaseChat(ABC):
|
|||||||
|
|
||||||
arbitrary_types_allowed = True
|
arbitrary_types_allowed = True
|
||||||
|
|
||||||
|
@trace("BaseChat.__init__")
|
||||||
def __init__(self, chat_param: Dict):
|
def __init__(self, chat_param: Dict):
|
||||||
"""Chat Module Initialization
|
"""Chat Module Initialization
|
||||||
Args:
|
Args:
|
||||||
@@ -175,6 +177,14 @@ class BaseChat(ABC):
|
|||||||
except StopAsyncIteration:
|
except StopAsyncIteration:
|
||||||
return True # 迭代器已经执行结束
|
return True # 迭代器已经执行结束
|
||||||
|
|
||||||
|
def _get_span_metadata(self, payload: Dict) -> Dict:
|
||||||
|
metadata = {k: v for k, v in payload.items()}
|
||||||
|
del metadata["prompt"]
|
||||||
|
metadata["messages"] = list(
|
||||||
|
map(lambda m: m if isinstance(m, dict) else m.dict(), metadata["messages"])
|
||||||
|
)
|
||||||
|
return metadata
|
||||||
|
|
||||||
async def stream_call(self):
|
async def stream_call(self):
|
||||||
# TODO Retry when server connection error
|
# TODO Retry when server connection error
|
||||||
payload = await self.__call_base()
|
payload = await self.__call_base()
|
||||||
@@ -182,6 +192,10 @@ class BaseChat(ABC):
|
|||||||
self.skip_echo_len = len(payload.get("prompt").replace("</s>", " ")) + 11
|
self.skip_echo_len = len(payload.get("prompt").replace("</s>", " ")) + 11
|
||||||
logger.info(f"Requert: \n{payload}")
|
logger.info(f"Requert: \n{payload}")
|
||||||
ai_response_text = ""
|
ai_response_text = ""
|
||||||
|
span = root_tracer.start_span(
|
||||||
|
"BaseChat.stream_call", metadata=self._get_span_metadata(payload)
|
||||||
|
)
|
||||||
|
payload["span_id"] = span.span_id
|
||||||
try:
|
try:
|
||||||
from pilot.model.cluster import WorkerManagerFactory
|
from pilot.model.cluster import WorkerManagerFactory
|
||||||
|
|
||||||
@@ -199,6 +213,7 @@ class BaseChat(ABC):
|
|||||||
self.current_message.add_ai_message(msg)
|
self.current_message.add_ai_message(msg)
|
||||||
view_msg = self.knowledge_reference_call(msg)
|
view_msg = self.knowledge_reference_call(msg)
|
||||||
self.current_message.add_view_message(view_msg)
|
self.current_message.add_view_message(view_msg)
|
||||||
|
span.end()
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(traceback.format_exc())
|
print(traceback.format_exc())
|
||||||
logger.error("model response parase faild!" + str(e))
|
logger.error("model response parase faild!" + str(e))
|
||||||
@@ -206,12 +221,17 @@ class BaseChat(ABC):
|
|||||||
f"""<span style=\"color:red\">ERROR!</span>{str(e)}\n {ai_response_text} """
|
f"""<span style=\"color:red\">ERROR!</span>{str(e)}\n {ai_response_text} """
|
||||||
)
|
)
|
||||||
### store current conversation
|
### store current conversation
|
||||||
|
span.end(metadata={"error": str(e)})
|
||||||
self.memory.append(self.current_message)
|
self.memory.append(self.current_message)
|
||||||
|
|
||||||
async def nostream_call(self):
|
async def nostream_call(self):
|
||||||
payload = await self.__call_base()
|
payload = await self.__call_base()
|
||||||
logger.info(f"Request: \n{payload}")
|
logger.info(f"Request: \n{payload}")
|
||||||
ai_response_text = ""
|
ai_response_text = ""
|
||||||
|
span = root_tracer.start_span(
|
||||||
|
"BaseChat.nostream_call", metadata=self._get_span_metadata(payload)
|
||||||
|
)
|
||||||
|
payload["span_id"] = span.span_id
|
||||||
try:
|
try:
|
||||||
from pilot.model.cluster import WorkerManagerFactory
|
from pilot.model.cluster import WorkerManagerFactory
|
||||||
|
|
||||||
@@ -219,6 +239,7 @@ class BaseChat(ABC):
|
|||||||
ComponentType.WORKER_MANAGER_FACTORY, WorkerManagerFactory
|
ComponentType.WORKER_MANAGER_FACTORY, WorkerManagerFactory
|
||||||
).create()
|
).create()
|
||||||
|
|
||||||
|
with root_tracer.start_span("BaseChat.invoke_worker_manager.generate"):
|
||||||
model_output = await worker_manager.generate(payload)
|
model_output = await worker_manager.generate(payload)
|
||||||
|
|
||||||
### output parse
|
### output parse
|
||||||
@@ -234,8 +255,15 @@ class BaseChat(ABC):
|
|||||||
ai_response_text
|
ai_response_text
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
metadata = {
|
||||||
|
"model_output": model_output.to_dict(),
|
||||||
|
"ai_response_text": ai_response_text,
|
||||||
|
"prompt_define_response": self._parse_prompt_define_response(
|
||||||
|
prompt_define_response
|
||||||
|
),
|
||||||
|
}
|
||||||
|
with root_tracer.start_span("BaseChat.do_action", metadata=metadata):
|
||||||
### run
|
### run
|
||||||
# result = self.do_action(prompt_define_response)
|
|
||||||
result = await blocking_func_to_async(
|
result = await blocking_func_to_async(
|
||||||
self._executor, self.do_action, prompt_define_response
|
self._executor, self.do_action, prompt_define_response
|
||||||
)
|
)
|
||||||
@@ -255,12 +283,14 @@ class BaseChat(ABC):
|
|||||||
|
|
||||||
view_message = view_message.replace("\n", "\\n")
|
view_message = view_message.replace("\n", "\\n")
|
||||||
self.current_message.add_view_message(view_message)
|
self.current_message.add_view_message(view_message)
|
||||||
|
span.end()
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(traceback.format_exc())
|
print(traceback.format_exc())
|
||||||
logger.error("model response parase faild!" + str(e))
|
logger.error("model response parase faild!" + str(e))
|
||||||
self.current_message.add_view_message(
|
self.current_message.add_view_message(
|
||||||
f"""<span style=\"color:red\">ERROR!</span>{str(e)}\n {ai_response_text} """
|
f"""<span style=\"color:red\">ERROR!</span>{str(e)}\n {ai_response_text} """
|
||||||
)
|
)
|
||||||
|
span.end(metadata={"error": str(e)})
|
||||||
### store dialogue
|
### store dialogue
|
||||||
self.memory.append(self.current_message)
|
self.memory.append(self.current_message)
|
||||||
return self.current_ai_response()
|
return self.current_ai_response()
|
||||||
@@ -513,3 +543,21 @@ class BaseChat(ABC):
|
|||||||
|
|
||||||
"""
|
"""
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
def _parse_prompt_define_response(self, prompt_define_response: Any) -> Any:
|
||||||
|
if not prompt_define_response:
|
||||||
|
return ""
|
||||||
|
if isinstance(prompt_define_response, str) or isinstance(
|
||||||
|
prompt_define_response, dict
|
||||||
|
):
|
||||||
|
return prompt_define_response
|
||||||
|
if isinstance(prompt_define_response, tuple):
|
||||||
|
if hasattr(prompt_define_response, "_asdict"):
|
||||||
|
# namedtuple
|
||||||
|
return prompt_define_response._asdict()
|
||||||
|
else:
|
||||||
|
return dict(
|
||||||
|
zip(range(len(prompt_define_response)), prompt_define_response)
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
return prompt_define_response
|
||||||
|
@@ -5,6 +5,7 @@ from pilot.scene.base import ChatScene
|
|||||||
from pilot.configs.config import Config
|
from pilot.configs.config import Config
|
||||||
|
|
||||||
from pilot.scene.chat_normal.prompt import prompt
|
from pilot.scene.chat_normal.prompt import prompt
|
||||||
|
from pilot.utils.tracer import root_tracer, trace
|
||||||
|
|
||||||
CFG = Config()
|
CFG = Config()
|
||||||
|
|
||||||
@@ -21,6 +22,7 @@ class ChatNormal(BaseChat):
|
|||||||
chat_param=chat_param,
|
chat_param=chat_param,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@trace()
|
||||||
async def generate_input_values(self) -> Dict:
|
async def generate_input_values(self) -> Dict:
|
||||||
input_values = {"input": self.current_user_input}
|
input_values = {"input": self.current_user_input}
|
||||||
return input_values
|
return input_values
|
||||||
|
@@ -10,6 +10,7 @@ from pilot.utils.tracer.base import (
|
|||||||
from pilot.utils.tracer.span_storage import MemorySpanStorage, FileSpanStorage
|
from pilot.utils.tracer.span_storage import MemorySpanStorage, FileSpanStorage
|
||||||
from pilot.utils.tracer.tracer_impl import (
|
from pilot.utils.tracer.tracer_impl import (
|
||||||
root_tracer,
|
root_tracer,
|
||||||
|
trace,
|
||||||
initialize_tracer,
|
initialize_tracer,
|
||||||
DefaultTracer,
|
DefaultTracer,
|
||||||
TracerManager,
|
TracerManager,
|
||||||
@@ -26,6 +27,7 @@ __all__ = [
|
|||||||
"MemorySpanStorage",
|
"MemorySpanStorage",
|
||||||
"FileSpanStorage",
|
"FileSpanStorage",
|
||||||
"root_tracer",
|
"root_tracer",
|
||||||
|
"trace",
|
||||||
"initialize_tracer",
|
"initialize_tracer",
|
||||||
"DefaultTracer",
|
"DefaultTracer",
|
||||||
"TracerManager",
|
"TracerManager",
|
||||||
|
@@ -340,7 +340,7 @@ def chat(
|
|||||||
table.add_row(["echo", metadata.get("echo")])
|
table.add_row(["echo", metadata.get("echo")])
|
||||||
elif "error" in metadata:
|
elif "error" in metadata:
|
||||||
table.add_row(["BaseChat Error", metadata.get("error")])
|
table.add_row(["BaseChat Error", metadata.get("error")])
|
||||||
if op == "BaseChat.nostream_call" and not sp["end_time"]:
|
if op == "BaseChat.do_action" and not sp["end_time"]:
|
||||||
if "model_output" in metadata:
|
if "model_output" in metadata:
|
||||||
table.add_row(
|
table.add_row(
|
||||||
[
|
[
|
||||||
@@ -361,11 +361,18 @@ def chat(
|
|||||||
]
|
]
|
||||||
)
|
)
|
||||||
if "prompt_define_response" in metadata:
|
if "prompt_define_response" in metadata:
|
||||||
|
prompt_define_response = metadata.get("prompt_define_response") or ""
|
||||||
|
if isinstance(prompt_define_response, dict) or isinstance(
|
||||||
|
prompt_define_response, type([])
|
||||||
|
):
|
||||||
|
prompt_define_response = json.dumps(
|
||||||
|
prompt_define_response, ensure_ascii=False
|
||||||
|
)
|
||||||
table.add_row(
|
table.add_row(
|
||||||
[
|
[
|
||||||
"BaseChat prompt_define_response",
|
"BaseChat prompt_define_response",
|
||||||
split_string_by_terminal_width(
|
split_string_by_terminal_width(
|
||||||
metadata.get("prompt_define_response"),
|
prompt_define_response,
|
||||||
split=split_long_text,
|
split=split_long_text,
|
||||||
),
|
),
|
||||||
]
|
]
|
||||||
|
@@ -1,6 +1,9 @@
|
|||||||
from typing import Dict, Optional
|
from typing import Dict, Optional
|
||||||
from contextvars import ContextVar
|
from contextvars import ContextVar
|
||||||
from functools import wraps
|
from functools import wraps
|
||||||
|
import asyncio
|
||||||
|
import inspect
|
||||||
|
|
||||||
|
|
||||||
from pilot.component import SystemApp, ComponentType
|
from pilot.component import SystemApp, ComponentType
|
||||||
from pilot.utils.tracer.base import (
|
from pilot.utils.tracer.base import (
|
||||||
@@ -154,18 +157,42 @@ class TracerManager:
|
|||||||
root_tracer: TracerManager = TracerManager()
|
root_tracer: TracerManager = TracerManager()
|
||||||
|
|
||||||
|
|
||||||
def trace(operation_name: str, **trace_kwargs):
|
def trace(operation_name: Optional[str] = None, **trace_kwargs):
|
||||||
def decorator(func):
|
def decorator(func):
|
||||||
@wraps(func)
|
@wraps(func)
|
||||||
async def wrapper(*args, **kwargs):
|
def sync_wrapper(*args, **kwargs):
|
||||||
with root_tracer.start_span(operation_name, **trace_kwargs):
|
name = (
|
||||||
|
operation_name if operation_name else _parse_operation_name(func, *args)
|
||||||
|
)
|
||||||
|
with root_tracer.start_span(name, **trace_kwargs):
|
||||||
|
return func(*args, **kwargs)
|
||||||
|
|
||||||
|
@wraps(func)
|
||||||
|
async def async_wrapper(*args, **kwargs):
|
||||||
|
name = (
|
||||||
|
operation_name if operation_name else _parse_operation_name(func, *args)
|
||||||
|
)
|
||||||
|
with root_tracer.start_span(name, **trace_kwargs):
|
||||||
return await func(*args, **kwargs)
|
return await func(*args, **kwargs)
|
||||||
|
|
||||||
return wrapper
|
if asyncio.iscoroutinefunction(func):
|
||||||
|
return async_wrapper
|
||||||
|
else:
|
||||||
|
return sync_wrapper
|
||||||
|
|
||||||
return decorator
|
return decorator
|
||||||
|
|
||||||
|
|
||||||
|
def _parse_operation_name(func, *args):
|
||||||
|
self_name = None
|
||||||
|
if inspect.signature(func).parameters.get("self"):
|
||||||
|
self_name = args[0].__class__.__name__
|
||||||
|
func_name = func.__name__
|
||||||
|
if self_name:
|
||||||
|
return f"{self_name}.{func_name}"
|
||||||
|
return func_name
|
||||||
|
|
||||||
|
|
||||||
def initialize_tracer(
|
def initialize_tracer(
|
||||||
system_app: SystemApp,
|
system_app: SystemApp,
|
||||||
tracer_filename: str,
|
tracer_filename: str,
|
||||||
|
Reference in New Issue
Block a user