feat(agent): Add trace for agent (#1407)

This commit is contained in:
Fangyin Cheng
2024-04-11 19:07:06 +08:00
committed by GitHub
parent 7d6dfd9ea8
commit aea575e0b4
19 changed files with 1126 additions and 89 deletions

View File

@@ -8,6 +8,7 @@ from typing import Any, Dict, List, Optional, Tuple, Type, Union, cast
from dbgpt._private.pydantic import Field
from dbgpt.core import LLMClient, ModelMessageRoleType
from dbgpt.util.error_types import LLMChatError
from dbgpt.util.tracer import SpanType, root_tracer
from dbgpt.util.utils import colored
from ..actions.action import Action, ActionOutput
@@ -199,13 +200,25 @@ class ConversableAgent(Role, Agent):
is_recovery: Optional[bool] = False,
) -> None:
"""Send a message to recipient agent."""
await recipient.receive(
message=message,
sender=self,
reviewer=reviewer,
request_reply=request_reply,
is_recovery=is_recovery,
)
with root_tracer.start_span(
"agent.send",
metadata={
"sender": self.get_name(),
"recipient": recipient.get_name(),
"reviewer": reviewer.get_name() if reviewer else None,
"agent_message": message.to_dict(),
"request_reply": request_reply,
"is_recovery": is_recovery,
"conv_uid": self.not_null_agent_context.conv_id,
},
):
await recipient.receive(
message=message,
sender=self,
reviewer=reviewer,
request_reply=request_reply,
is_recovery=is_recovery,
)
async def receive(
self,
@@ -217,16 +230,30 @@ class ConversableAgent(Role, Agent):
is_recovery: Optional[bool] = False,
) -> None:
"""Receive a message from another agent."""
await self._a_process_received_message(message, sender)
if request_reply is False or request_reply is None:
return
with root_tracer.start_span(
"agent.receive",
metadata={
"sender": sender.get_name(),
"recipient": self.get_name(),
"reviewer": reviewer.get_name() if reviewer else None,
"agent_message": message.to_dict(),
"request_reply": request_reply,
"silent": silent,
"is_recovery": is_recovery,
"conv_uid": self.not_null_agent_context.conv_id,
"is_human": self.is_human,
},
):
await self._a_process_received_message(message, sender)
if request_reply is False or request_reply is None:
return
if not self.is_human:
reply = await self.generate_reply(
received_message=message, sender=sender, reviewer=reviewer
)
if reply is not None:
await self.send(reply, sender)
if not self.is_human:
reply = await self.generate_reply(
received_message=message, sender=sender, reviewer=reviewer
)
if reply is not None:
await self.send(reply, sender)
def prepare_act_param(self) -> Dict[str, Any]:
"""Prepare the parameters for the act method."""
@@ -244,13 +271,44 @@ class ConversableAgent(Role, Agent):
logger.info(
f"generate agent reply!sender={sender}, rely_messages_len={rely_messages}"
)
root_span = root_tracer.start_span(
"agent.generate_reply",
metadata={
"sender": sender.get_name(),
"recipient": self.get_name(),
"reviewer": reviewer.get_name() if reviewer else None,
"received_message": received_message.to_dict(),
"conv_uid": self.not_null_agent_context.conv_id,
"rely_messages": [msg.to_dict() for msg in rely_messages]
if rely_messages
else None,
},
)
try:
reply_message: AgentMessage = self._init_reply_message(
received_message=received_message
)
await self._system_message_assembly(
received_message.content, reply_message.context
)
with root_tracer.start_span(
"agent.generate_reply._init_reply_message",
metadata={
"received_message": received_message.to_dict(),
},
) as span:
# initialize reply message
reply_message: AgentMessage = self._init_reply_message(
received_message=received_message
)
span.metadata["reply_message"] = reply_message.to_dict()
with root_tracer.start_span(
"agent.generate_reply._system_message_assembly",
metadata={
"reply_message": reply_message.to_dict(),
},
) as span:
# assemble system message
await self._system_message_assembly(
received_message.content, reply_message.context
)
span.metadata["assembled_system_messages"] = self.oai_system_message
fail_reason = None
current_retry_counter = 0
@@ -270,36 +328,73 @@ class ConversableAgent(Role, Agent):
retry_message, self, reviewer, request_reply=False
)
# 1.Think about how to do things
llm_reply, model_name = await self.thinking(
self._load_thinking_messages(
received_message, sender, rely_messages
thinking_messages = self._load_thinking_messages(
received_message, sender, rely_messages
)
with root_tracer.start_span(
"agent.generate_reply.thinking",
metadata={
"thinking_messages": [
msg.to_dict() for msg in thinking_messages
],
},
) as span:
# 1.Think about how to do things
llm_reply, model_name = await self.thinking(thinking_messages)
reply_message.model_name = model_name
reply_message.content = llm_reply
span.metadata["llm_reply"] = llm_reply
span.metadata["model_name"] = model_name
with root_tracer.start_span(
"agent.generate_reply.review",
metadata={"llm_reply": llm_reply, "censored": self.get_name()},
) as span:
# 2.Review whether what is being done is legal
approve, comments = await self.review(llm_reply, self)
reply_message.review_info = AgentReviewInfo(
approve=approve,
comments=comments,
)
)
reply_message.model_name = model_name
reply_message.content = llm_reply
span.metadata["approve"] = approve
span.metadata["comments"] = comments
# 2.Review whether what is being done is legal
approve, comments = await self.review(llm_reply, self)
reply_message.review_info = AgentReviewInfo(
approve=approve,
comments=comments,
)
# 3.Act based on the results of your thinking
act_extent_param = self.prepare_act_param()
act_out: Optional[ActionOutput] = await self.act(
message=llm_reply,
sender=sender,
reviewer=reviewer,
**act_extent_param,
)
if act_out:
reply_message.action_report = act_out.dict()
with root_tracer.start_span(
"agent.generate_reply.act",
metadata={
"llm_reply": llm_reply,
"sender": sender.get_name(),
"reviewer": reviewer.get_name() if reviewer else None,
"act_extent_param": act_extent_param,
},
) as span:
# 3.Act based on the results of your thinking
act_out: Optional[ActionOutput] = await self.act(
message=llm_reply,
sender=sender,
reviewer=reviewer,
**act_extent_param,
)
if act_out:
reply_message.action_report = act_out.dict()
span.metadata["action_report"] = act_out.dict() if act_out else None
# 4.Reply information verification
check_pass, reason = await self.verify(reply_message, sender, reviewer)
is_success = check_pass
with root_tracer.start_span(
"agent.generate_reply.verify",
metadata={
"llm_reply": llm_reply,
"sender": sender.get_name(),
"reviewer": reviewer.get_name() if reviewer else None,
},
) as span:
# 4.Reply information verification
check_pass, reason = await self.verify(
reply_message, sender, reviewer
)
is_success = check_pass
span.metadata["check_pass"] = check_pass
span.metadata["reason"] = reason
# 5.Optimize wrong answers myself
if not check_pass:
current_retry_counter += 1
@@ -319,6 +414,9 @@ class ConversableAgent(Role, Agent):
err_message = AgentMessage(content=str(e))
err_message.success = False
return err_message
finally:
root_span.metadata["reply_message"] = reply_message.to_dict()
root_span.end()
async def thinking(
self, messages: List[AgentMessage], prompt: Optional[str] = None
@@ -378,7 +476,7 @@ class ConversableAgent(Role, Agent):
) -> Optional[ActionOutput]:
"""Perform actions."""
last_out: Optional[ActionOutput] = None
for action in self.actions:
for i, action in enumerate(self.actions):
# Select the resources required by action
need_resource = None
if self.resources and len(self.resources) > 0:
@@ -390,12 +488,27 @@ class ConversableAgent(Role, Agent):
if not message:
raise ValueError("The message content is empty!")
last_out = await action.run(
ai_message=message,
resource=need_resource,
rely_action_out=last_out,
**kwargs,
)
with root_tracer.start_span(
"agent.act.run",
metadata={
"message": message,
"sender": sender.get_name() if sender else None,
"recipient": self.get_name(),
"reviewer": reviewer.get_name() if reviewer else None,
"need_resource": need_resource.to_dict() if need_resource else None,
"rely_action_out": last_out.dict() if last_out else None,
"conv_uid": self.not_null_agent_context.conv_id,
"action_index": i,
"total_action": len(self.actions),
},
) as span:
last_out = await action.run(
ai_message=message,
resource=need_resource,
rely_action_out=last_out,
**kwargs,
)
span.metadata["action_out"] = last_out.dict() if last_out else None
return last_out
async def correctness_check(
@@ -446,12 +559,24 @@ class ConversableAgent(Role, Agent):
reviewer (Agent): The reviewer agent.
message (str): The message to send.
"""
await self.send(
AgentMessage(content=message, current_goal=message),
recipient,
reviewer,
request_reply=True,
)
agent_message = AgentMessage(content=message, current_goal=message)
with root_tracer.start_span(
"agent.initiate_chat",
span_type=SpanType.AGENT,
metadata={
"sender": self.get_name(),
"recipient": recipient.get_name(),
"reviewer": reviewer.get_name() if reviewer else None,
"agent_message": agent_message.to_dict(),
"conv_uid": self.not_null_agent_context.conv_id,
},
):
await self.send(
agent_message,
recipient,
reviewer,
request_reply=True,
)
#######################################################################
# Private Function Begin
@@ -506,8 +631,15 @@ class ConversableAgent(Role, Agent):
model_name=oai_message.get("model_name", None),
)
self.memory.message_memory.append(gpts_message)
return True
with root_tracer.start_span(
"agent.save_message_to_memory",
metadata={
"gpts_message": gpts_message.to_dict(),
"conv_uid": self.not_null_agent_context.conv_id,
},
):
self.memory.message_memory.append(gpts_message)
return True
def _print_received_message(self, message: AgentMessage, sender: Agent):
# print the message received
@@ -711,14 +843,27 @@ class ConversableAgent(Role, Agent):
# Convert and tailor the information in collective memory into contextual
# memory available to the current Agent
current_goal_messages = self._convert_to_ai_message(
self.memory.message_memory.get_between_agents(
with root_tracer.start_span(
"agent._load_thinking_messages",
metadata={
"sender": sender.get_name(),
"recipient": self.get_name(),
"conv_uid": self.not_null_agent_context.conv_id,
"current_goal": current_goal,
},
) as span:
# Get historical information from the memory
memory_messages = self.memory.message_memory.get_between_agents(
self.not_null_agent_context.conv_id,
self.profile,
sender.get_profile(),
current_goal,
)
)
span.metadata["memory_messages"] = [
message.to_dict() for message in memory_messages
]
current_goal_messages = self._convert_to_ai_message(memory_messages)
# When there is no target and context, the current received message is used as
# the target problem

View File

@@ -25,7 +25,17 @@ CHECK_RESULT_SYSTEM_MESSAGE = (
"such as: True.\n"
" If it is determined to be a failure, return false and the reason, "
"such as: False. There are no numbers in the execution results that answer the "
"computational goals of the mission."
"computational goals of the mission.\n"
"You can refer to the following examples:\n"
"user: Please understand the following task objectives and results and give your "
"judgment:\nTask goal: Calculate the result of 1 + 2 using Python code.\n"
"Execution Result: 3\n"
"assistant: True\n\n"
"user: Please understand the following task objectives and results and give your "
"judgment:\nTask goal: Calculate the result of 100 * 10 using Python code.\n"
"Execution Result: 'you can get the result by multiplying 100 by 10'\n"
"assistant: False. There are no numbers in the execution results that answer the "
"computational goals of the mission.\n"
)

View File

@@ -1,7 +1,7 @@
"""Database resource client API."""
import logging
from contextlib import contextmanager
from typing import Iterator, List, Optional, Union
from typing import TYPE_CHECKING, Iterator, List, Optional, Union
from .resource_api import AgentResource, ResourceClient, ResourceType
@@ -48,7 +48,8 @@ class ResourceDbClient(ResourceClient):
class SqliteLoadClient(ResourceDbClient):
"""SQLite resource client."""
from sqlalchemy.orm.session import Session
if TYPE_CHECKING:
from sqlalchemy.orm.session import Session
def __init__(self):
"""Create a SQLite resource client."""
@@ -59,7 +60,7 @@ class SqliteLoadClient(ResourceDbClient):
return "sqlite"
@contextmanager
def connect(self, db) -> Iterator[Session]:
def connect(self, db) -> Iterator["Session"]:
"""Connect to the database."""
from sqlalchemy import create_engine
from sqlalchemy.orm import sessionmaker

View File

@@ -1 +1,4 @@
"""Utils for the agent module."""

from .cmp import cmp_string_equal  # noqa: F401

# Fix: ``__ALL__`` has no meaning to Python; the import-system convention
# for declaring the public API (used by ``from pkg import *``) is ``__all__``.
__all__ = ["cmp_string_equal"]

View File

@@ -223,8 +223,8 @@ def run_webserver(param: WebServerParameters = None):
if not param:
param = _get_webserver_params()
initialize_tracer(
system_app,
os.path.join(LOGDIR, param.tracer_file),
system_app=system_app,
tracer_storage_cls=param.tracer_storage_cls,
)

View File

@@ -5,8 +5,10 @@ Manages the lifecycle and registration of components.
from __future__ import annotations
import asyncio
import atexit
import logging
import sys
import threading
from abc import ABC, abstractmethod
from enum import Enum
from typing import TYPE_CHECKING, Dict, Optional, Type, TypeVar, Union
@@ -162,6 +164,8 @@ class SystemApp(LifeCycle):
] = {} # Dictionary to store registered components.
self._asgi_app = asgi_app
self._app_config = app_config or AppConfig()
self._stop_event = threading.Event()
self._stop_event.clear()
self._build()
@property
@@ -273,11 +277,14 @@ class SystemApp(LifeCycle):
def before_stop(self):
    """Invoke the before_stop hooks for all registered components.

    Idempotent: guarded by ``_stop_event`` so the hooks run at most once
    even if this is called from several shutdown paths (e.g. an ASGI
    shutdown event and an ``atexit`` handler).

    A failing component must not prevent the remaining components from
    stopping, so per-component errors are logged and otherwise ignored.
    """
    if self._stop_event.is_set():
        return
    # Iterate values directly: the component key is not needed here.
    for component in self.components.values():
        try:
            component.before_stop()
        except Exception:
            # Log instead of silently swallowing, but keep shutting down.
            logging.getLogger(__name__).warning(
                "before_stop failed for component %r", component, exc_info=True
            )
    self._stop_event.set()
async def async_before_stop(self):
"""Asynchronously invoke the before_stop hooks for all registered components."""
@@ -287,6 +294,7 @@ class SystemApp(LifeCycle):
def _build(self):
"""Integrate lifecycle events with the internal ASGI app if available."""
if not self.app:
self._register_exit_handler()
return
@self.app.on_event("startup")
@@ -308,3 +316,7 @@ class SystemApp(LifeCycle):
"""ASGI app shutdown event handler."""
await self.async_before_stop()
self.before_stop()
def _register_exit_handler(self):
    """Register an exit handler to stop the system app."""
    # Used when no ASGI app is available to hook shutdown events
    # (see _build): fall back to atexit so before_stop still runs
    # once on interpreter exit.
    atexit.register(self.before_stop)

View File

@@ -1126,8 +1126,8 @@ def run_worker_manager(
system_app = SystemApp(app)
initialize_tracer(
system_app,
os.path.join(LOGDIR, worker_params.tracer_file),
system_app=system_app,
root_operation_name="DB-GPT-WorkerManager-Entry",
tracer_storage_cls=worker_params.tracer_storage_cls,
)

View File

@@ -15,6 +15,7 @@ class SpanType(str, Enum):
BASE = "base"
RUN = "run"
CHAT = "chat"
AGENT = "agent"
class SpanTypeRunName(str, Enum):
@@ -99,6 +100,21 @@ class Span:
"metadata": _clean_for_json(self.metadata),
}
def copy(self) -> "Span":
    """Create a copy of this span.

    Returns a new Span with the same identifiers, operation name and
    start/end times. The metadata dict is shallow-copied so the copy's
    metadata can be mutated independently (nested values are still
    shared with the original).
    """
    # Shallow copy: top-level keys are independent, values are shared.
    metadata = self.metadata.copy() if self.metadata else None
    span = Span(
        self.trace_id,
        self.span_id,
        self.span_type,
        self.parent_span_id,
        self.operation_name,
        metadata=metadata,
    )
    # Preserve timing so the copy describes the same span interval.
    span.start_time = self.start_time
    span.end_time = self.end_time
    return span
class SpanStorageType(str, Enum):
ON_CREATE = "on_create"
@@ -191,7 +207,7 @@ class TracerContext:
def _clean_for_json(data: Optional[str, Any] = None):
if not data:
if data is None:
return None
if isinstance(data, dict):
cleaned_dict = {}

View File

@@ -49,7 +49,9 @@ class SpanStorageContainer(SpanStorage):
self.flush_thread = threading.Thread(
target=self._flush_to_storages, daemon=True
)
self._stop_event = threading.Event()
self.flush_thread.start()
self._stop_event.clear()
def append_storage(self, storage: SpanStorage):
"""Append storage to container
@@ -68,7 +70,7 @@ class SpanStorageContainer(SpanStorage):
pass # If the signal queue is full, it's okay. The flush thread will handle it.
def _flush_to_storages(self):
while True:
while not self._stop_event.is_set():
interval = time.time() - self.last_flush_time
if interval < self.flush_interval:
try:
@@ -90,13 +92,24 @@ class SpanStorageContainer(SpanStorage):
try:
storage.append_span_batch(spans_to_write)
except Exception as e:
logger.warn(
logger.warning(
f"Append spans to storage {str(storage)} failed: {str(e)}, span_data: {spans_to_write}"
)
self.executor.submit(append_and_ignore_error, s, spans_to_write)
try:
self.executor.submit(append_and_ignore_error, s, spans_to_write)
except RuntimeError:
append_and_ignore_error(s, spans_to_write)
self.last_flush_time = time.time()
def before_stop(self):
    """Flush pending spans and stop the background flush thread.

    Wakes the flush thread via the signal queue, sets the stop event so
    its loop exits, then joins the thread so buffered spans are written
    out before shutdown. Any error is ignored: shutdown must not fail
    because of tracing cleanup.
    """
    try:
        self.flush_signal_queue.put(True)
        self._stop_event.set()
        self.flush_thread.join()
    except Exception:
        pass
class FileSpanStorage(SpanStorage):
def __init__(self, filename: str):

View File

@@ -52,7 +52,7 @@ def test_start_and_end_span(tracer: Tracer):
assert span.end_time is not None
stored_span = tracer._get_current_storage().spans[0]
assert stored_span == span
assert stored_span.span_id == span.span_id
def test_start_and_end_span_with_tracer_manager(tracer_manager: TracerManager):
@@ -76,8 +76,12 @@ def test_parent_child_span_relation(tracer: Tracer):
tracer.end_span(child_span)
tracer.end_span(parent_span)
assert parent_span in tracer._get_current_storage().spans
assert child_span in tracer._get_current_storage().spans
assert parent_span.operation_name in [
s.operation_name for s in tracer._get_current_storage().spans
]
assert child_span.operation_name in [
s.operation_name for s in tracer._get_current_storage().spans
]
@pytest.mark.parametrize(

View File

@@ -36,7 +36,7 @@ class DefaultTracer(Tracer):
self._span_storage_type = span_storage_type
def append_span(self, span: Span):
self._get_current_storage().append_span(span)
self._get_current_storage().append_span(span.copy())
def start_span(
self,
@@ -130,9 +130,13 @@ class TracerManager:
"""
tracer = self._get_tracer()
if not tracer:
return Span("empty_span", "empty_span")
return Span(
"empty_span", "empty_span", span_type=span_type, metadata=metadata
)
if not parent_span_id:
parent_span_id = self.get_current_span_id()
if not span_type and parent_span_id:
span_type = self._get_current_span_type()
return tracer.start_span(
operation_name, parent_span_id, span_type=span_type, metadata=metadata
)
@@ -156,6 +160,10 @@ class TracerManager:
ctx = self._trace_context_var.get()
return ctx.span_id if ctx else None
def _get_current_span_type(self) -> Optional[SpanType]:
    """Return the type of the currently active span, or ``None`` if there
    is no active span."""
    current = self.get_current_span()
    if not current:
        return None
    return current.span_type
root_tracer: TracerManager = TracerManager()
@@ -197,14 +205,19 @@ def _parse_operation_name(func, *args):
def initialize_tracer(
system_app: SystemApp,
tracer_filename: str,
root_operation_name: str = "DB-GPT-Web-Entry",
tracer_storage_cls: str = None,
system_app: Optional[SystemApp] = None,
tracer_storage_cls: Optional[str] = None,
create_system_app: bool = False,
):
"""Initialize the tracer with the given filename and system app."""
from dbgpt.util.tracer.span_storage import FileSpanStorage, SpanStorageContainer
if not system_app and create_system_app:
system_app = SystemApp()
if not system_app:
return
from dbgpt.util.tracer.span_storage import FileSpanStorage, SpanStorageContainer
trace_context_var = ContextVar(
"trace_context",

View File

@@ -1,5 +1,6 @@
"""GPT-Vis Module."""
from .base import Vis # noqa: F401
from .client import vis_client # noqa: F401
from .tags.vis_agent_message import VisAgentMessages # noqa: F401
from .tags.vis_agent_plans import VisAgentPlans # noqa: F401
@@ -9,6 +10,7 @@ from .tags.vis_dashboard import VisDashboard # noqa: F401
from .tags.vis_plugin import VisPlugin # noqa: F401
__ALL__ = [
"Vis",
"vis_client",
"VisAgentMessages",
"VisAgentPlans",