refactor(agent): Agent modular refactoring (#1487)

This commit is contained in:
Fangyin Cheng
2024-05-07 09:45:26 +08:00
committed by GitHub
parent 2a418f91e8
commit 863b5404dd
86 changed files with 4513 additions and 967 deletions

View File

@@ -3,7 +3,7 @@
import asyncio
import json
import logging
from typing import Any, Dict, List, Optional, Tuple, Type, Union, cast
from typing import Any, Dict, List, Optional, Tuple, Type, cast
from dbgpt._private.pydantic import ConfigDict, Field
from dbgpt.core import LLMClient, ModelMessageRoleType
@@ -11,14 +11,15 @@ from dbgpt.util.error_types import LLMChatError
from dbgpt.util.tracer import SpanType, root_tracer
from dbgpt.util.utils import colored
from ..actions.action import Action, ActionOutput
from ..memory.base import GptsMessage
from ..memory.gpts_memory import GptsMemory
from ..resource.resource_api import AgentResource, ResourceClient
from ..resource.resource_loader import ResourceLoader
from ..util.llm.llm import LLMConfig, LLMStrategyType
from ..util.llm.llm_client import AIWrapper
from .action.base import Action, ActionOutput
from .agent import Agent, AgentContext, AgentMessage, AgentReviewInfo
from .llm.llm import LLMConfig, LLMStrategyType
from .llm.llm_client import AIWrapper
from .memory.agent_memory import AgentMemory
from .memory.gpts.base import GptsMessage
from .memory.gpts.gpts_memory import GptsMemory
from .role import Role
logger = logging.getLogger(__name__)
@@ -33,26 +34,16 @@ class ConversableAgent(Role, Agent):
actions: List[Action] = Field(default_factory=list)
resources: List[AgentResource] = Field(default_factory=list)
llm_config: Optional[LLMConfig] = None
memory: GptsMemory = Field(default_factory=GptsMemory)
resource_loader: Optional[ResourceLoader] = None
max_retry_count: int = 3
consecutive_auto_reply_counter: int = 0
llm_client: Optional[AIWrapper] = None
oai_system_message: List[Dict] = Field(default_factory=list)
def __init__(self, **kwargs):
"""Create a new agent."""
Role.__init__(self, **kwargs)
Agent.__init__(self)
def init_system_message(self) -> None:
"""Initialize the system message."""
content = self.prompt_template()
# TODO: Don't modify the original data, need to be optimized
self.oai_system_message = [
{"content": content, "role": ModelMessageRoleType.SYSTEM}
]
def check_available(self) -> None:
"""Check if the agent is available.
@@ -63,7 +54,7 @@ class ConversableAgent(Role, Agent):
# check run context
if self.agent_context is None:
raise ValueError(
f"{self.name}[{self.profile}] Missing context in which agent is "
f"{self.name}[{self.role}] Missing context in which agent is "
f"running!"
)
@@ -90,20 +81,20 @@ class ConversableAgent(Role, Agent):
and action.resource_need not in have_resource_types
):
raise ValueError(
f"{self.name}[{self.profile}] Missing resources required for "
f"{self.name}[{self.role}] Missing resources required for "
"runtime"
)
else:
if not self.is_human and not self.is_team:
raise ValueError(
f"This agent {self.name}[{self.profile}] is missing action modules."
f"This agent {self.name}[{self.role}] is missing action modules."
)
# llm check
if not self.is_human and (
self.llm_config is None or self.llm_config.llm_client is None
):
raise ValueError(
f"{self.name}[{self.profile}] Model configuration is missing or model "
f"{self.name}[{self.role}] Model configuration is missing or model "
"service is unavailable"
)
@@ -161,14 +152,19 @@ class ConversableAgent(Role, Agent):
for action in self.actions:
action.init_resource_loader(self.resource_loader)
# Initialize system messages
self.init_system_message()
# Initialize LLM Server
if not self.is_human:
if not self.llm_config or not self.llm_config.llm_client:
raise ValueError("LLM client is not initialized")
self.llm_client = AIWrapper(llm_client=self.llm_config.llm_client)
self.memory.initialize(
self.name,
self.llm_config.llm_client,
importance_scorer=self.memory_importance_scorer,
insight_extractor=self.memory_insight_extractor,
)
# Clone the memory structure
self.memory = self.memory.structure_clone()
return self
def bind(self, target: Any) -> "ConversableAgent":
@@ -176,7 +172,7 @@ class ConversableAgent(Role, Agent):
if isinstance(target, LLMConfig):
self.llm_config = target
elif isinstance(target, GptsMemory):
self.memory = target
raise ValueError("GptsMemory is not supported!")
elif isinstance(target, AgentContext):
self.agent_context = target
elif isinstance(target, ResourceLoader):
@@ -186,6 +182,8 @@ class ConversableAgent(Role, Agent):
self.actions.extend(target)
elif _is_list_of_type(target, AgentResource):
self.resources = target
elif isinstance(target, AgentMemory):
self.memory = target
return self
async def send(
@@ -200,9 +198,9 @@ class ConversableAgent(Role, Agent):
with root_tracer.start_span(
"agent.send",
metadata={
"sender": self.get_name(),
"recipient": recipient.get_name(),
"reviewer": reviewer.get_name() if reviewer else None,
"sender": self.name,
"recipient": recipient.name,
"reviewer": reviewer.name if reviewer else None,
"agent_message": message.to_dict(),
"request_reply": request_reply,
"is_recovery": is_recovery,
@@ -230,9 +228,9 @@ class ConversableAgent(Role, Agent):
with root_tracer.start_span(
"agent.receive",
metadata={
"sender": sender.get_name(),
"recipient": self.get_name(),
"reviewer": reviewer.get_name() if reviewer else None,
"sender": sender.name,
"recipient": self.name,
"reviewer": reviewer.name if reviewer else None,
"agent_message": message.to_dict(),
"request_reply": request_reply,
"silent": silent,
@@ -271,14 +269,14 @@ class ConversableAgent(Role, Agent):
root_span = root_tracer.start_span(
"agent.generate_reply",
metadata={
"sender": sender.get_name(),
"recipient": self.get_name(),
"reviewer": reviewer.get_name() if reviewer else None,
"sender": sender.name,
"recipient": self.name,
"reviewer": reviewer.name if reviewer else None,
"received_message": received_message.to_dict(),
"conv_uid": self.not_null_agent_context.conv_id,
"rely_messages": [msg.to_dict() for msg in rely_messages]
if rely_messages
else None,
"rely_messages": (
[msg.to_dict() for msg in rely_messages] if rely_messages else None
),
},
)
@@ -295,18 +293,6 @@ class ConversableAgent(Role, Agent):
)
span.metadata["reply_message"] = reply_message.to_dict()
with root_tracer.start_span(
"agent.generate_reply._system_message_assembly",
metadata={
"reply_message": reply_message.to_dict(),
},
) as span:
# assemble system message
await self._system_message_assembly(
received_message.content, reply_message.context
)
span.metadata["assembled_system_messages"] = self.oai_system_message
fail_reason = None
current_retry_counter = 0
is_success = True
@@ -325,8 +311,11 @@ class ConversableAgent(Role, Agent):
retry_message, self, reviewer, request_reply=False
)
thinking_messages = self._load_thinking_messages(
received_message, sender, rely_messages
thinking_messages = await self._load_thinking_messages(
received_message,
sender,
rely_messages,
context=reply_message.get_dict_context(),
)
with root_tracer.start_span(
"agent.generate_reply.thinking",
@@ -345,7 +334,7 @@ class ConversableAgent(Role, Agent):
with root_tracer.start_span(
"agent.generate_reply.review",
metadata={"llm_reply": llm_reply, "censored": self.get_name()},
metadata={"llm_reply": llm_reply, "censored": self.name},
) as span:
# 2.Review whether what is being done is legal
approve, comments = await self.review(llm_reply, self)
@@ -361,8 +350,8 @@ class ConversableAgent(Role, Agent):
"agent.generate_reply.act",
metadata={
"llm_reply": llm_reply,
"sender": sender.get_name(),
"reviewer": reviewer.get_name() if reviewer else None,
"sender": sender.name,
"reviewer": reviewer.name if reviewer else None,
"act_extent_param": act_extent_param,
},
) as span:
@@ -383,8 +372,8 @@ class ConversableAgent(Role, Agent):
"agent.generate_reply.verify",
metadata={
"llm_reply": llm_reply,
"sender": sender.get_name(),
"reviewer": reviewer.get_name() if reviewer else None,
"sender": sender.name,
"reviewer": reviewer.name if reviewer else None,
},
) as span:
# 4.Reply information verification
@@ -394,6 +383,9 @@ class ConversableAgent(Role, Agent):
is_success = check_pass
span.metadata["check_pass"] = check_pass
span.metadata["reason"] = reason
question: str = received_message.content or ""
ai_message: str = llm_reply or ""
# 5.Optimize wrong answers myself
if not check_pass:
current_retry_counter += 1
@@ -403,7 +395,20 @@ class ConversableAgent(Role, Agent):
reply_message, sender, reviewer, request_reply=False
)
fail_reason = reason
await self.save_to_memory(
question=question,
ai_message=ai_message,
action_output=act_out,
check_pass=check_pass,
check_fail_reason=fail_reason,
)
else:
await self.save_to_memory(
question=question,
ai_message=ai_message,
action_output=act_out,
check_pass=check_pass,
)
break
reply_message.success = is_success
return reply_message
@@ -437,8 +442,6 @@ class ConversableAgent(Role, Agent):
try:
if prompt:
llm_messages = _new_system_message(prompt) + llm_messages
else:
llm_messages = self.oai_system_message + llm_messages
if not self.llm_client:
raise ValueError("LLM client is not initialized!")
@@ -491,9 +494,9 @@ class ConversableAgent(Role, Agent):
"agent.act.run",
metadata={
"message": message,
"sender": sender.get_name() if sender else None,
"recipient": self.get_name(),
"reviewer": reviewer.get_name() if reviewer else None,
"sender": sender.name if sender else None,
"recipient": self.name,
"reviewer": reviewer.name if reviewer else None,
"need_resource": need_resource.to_dict() if need_resource else None,
"rely_action_out": last_out.to_dict() if last_out else None,
"conv_uid": self.not_null_agent_context.conv_id,
@@ -563,9 +566,9 @@ class ConversableAgent(Role, Agent):
"agent.initiate_chat",
span_type=SpanType.AGENT,
metadata={
"sender": self.get_name(),
"recipient": recipient.get_name(),
"reviewer": reviewer.get_name() if reviewer else None,
"sender": self.name,
"recipient": recipient.name,
"reviewer": reviewer.name if reviewer else None,
"agent_message": agent_message.to_dict(),
"conv_uid": self.not_null_agent_context.conv_id,
},
@@ -612,21 +615,27 @@ class ConversableAgent(Role, Agent):
gpts_message: GptsMessage = GptsMessage(
conv_id=self.not_null_agent_context.conv_id,
sender=sender.get_profile(),
receiver=self.profile,
sender=sender.role,
receiver=self.role,
role=role,
rounds=self.consecutive_auto_reply_counter,
current_goal=oai_message.get("current_goal", None),
content=oai_message.get("content", None),
context=json.dumps(oai_message["context"], ensure_ascii=False)
if "context" in oai_message
else None,
review_info=json.dumps(oai_message["review_info"], ensure_ascii=False)
if "review_info" in oai_message
else None,
action_report=json.dumps(oai_message["action_report"], ensure_ascii=False)
if "action_report" in oai_message
else None,
context=(
json.dumps(oai_message["context"], ensure_ascii=False)
if "context" in oai_message
else None
),
review_info=(
json.dumps(oai_message["review_info"], ensure_ascii=False)
if "review_info" in oai_message
else None
),
action_report=(
json.dumps(oai_message["action_report"], ensure_ascii=False)
if "action_report" in oai_message
else None
),
model_name=oai_message.get("model_name", None),
)
@@ -643,10 +652,10 @@ class ConversableAgent(Role, Agent):
def _print_received_message(self, message: AgentMessage, sender: Agent):
# print the message received
print("\n", "-" * 80, flush=True, sep="")
_print_name = self.name if self.name else self.profile
_print_name = self.name if self.name else self.role
print(
colored(
sender.get_name() if sender.get_name() else sender.get_profile(),
sender.name if sender.name else sender.role,
"yellow",
),
"(to",
@@ -660,7 +669,7 @@ class ConversableAgent(Role, Agent):
review_info = message.review_info
if review_info:
name = sender.get_name() if sender.get_name() else sender.get_profile()
name = sender.name if sender.name else sender.role
pass_msg = "Pass" if review_info.approve else "Reject"
review_msg = f"{pass_msg}({review_info.comments})"
approve_print = f">>>>>>>>{name} Review info: \n{review_msg}"
@@ -668,7 +677,7 @@ class ConversableAgent(Role, Agent):
action_report = message.action_report
if action_report:
name = sender.get_name() if sender.get_name() else sender.get_profile()
name = sender.name if sender.name else sender.role
action_msg = (
"execution succeeded"
if action_report["is_exe_success"]
@@ -690,42 +699,32 @@ class ConversableAgent(Role, Agent):
self._print_received_message(message, sender)
async def _system_message_assembly(
self, question: Optional[str], context: Optional[Union[str, Dict]] = None
):
# system message
self.init_system_message()
if len(self.oai_system_message) > 0:
resource_prompt_list = []
for item in self.resources:
resource_client = self.not_null_resource_loader.get_resource_api(
item.type, ResourceClient
async def generate_resource_variables(
self, question: Optional[str] = None
) -> Dict[str, Any]:
"""Generate the resource variables."""
resource_prompt_list = []
for item in self.resources:
resource_client = self.not_null_resource_loader.get_resource_api(
item.type, ResourceClient
)
if not resource_client:
raise ValueError(
f"Resource {item.type}:{item.value} missing resource loader"
f" implementation,unable to read resources!"
)
if not resource_client:
raise ValueError(
f"Resource {item.type}:{item.value} missing resource loader"
f" implementation,unable to read resources!"
)
resource_prompt_list.append(
await resource_client.get_resource_prompt(item, question)
)
if context is None or not isinstance(context, dict):
context = {}
resource_prompt_list.append(
await resource_client.get_resource_prompt(item, question)
)
resource_prompt = ""
if len(resource_prompt_list) > 0:
resource_prompt = "RESOURCES:" + "\n".join(resource_prompt_list)
resource_prompt = ""
if len(resource_prompt_list) > 0:
resource_prompt = "RESOURCES:" + "\n".join(resource_prompt_list)
out_schema: Optional[str] = ""
if self.actions and len(self.actions) > 0:
out_schema = self.actions[0].ai_out_schema
for message in self.oai_system_message:
new_content = message["content"].format(
resource_prompt=resource_prompt,
out_schema=out_schema,
**context,
)
message["content"] = new_content
out_schema: Optional[str] = ""
if self.actions and len(self.actions) > 0:
out_schema = self.actions[0].ai_out_schema
return {"resource_prompt": resource_prompt, "out_schema": out_schema}
def _excluded_models(
self,
@@ -774,7 +773,7 @@ class ConversableAgent(Role, Agent):
else:
raise ValueError("No model service available!")
except Exception as e:
logger.error(f"{self.profile} get next llm failed!{str(e)}")
logger.error(f"{self.role} get next llm failed!{str(e)}")
raise ValueError(f"Failed to allocate model service,{str(e)}!")
def _init_reply_message(self, received_message: AgentMessage) -> AgentMessage:
@@ -803,9 +802,9 @@ class ConversableAgent(Role, Agent):
if item.role:
role = item.role
else:
if item.receiver == self.profile:
if item.receiver == self.role:
role = ModelMessageRoleType.HUMAN
elif item.sender == self.profile:
elif item.sender == self.role:
role = ModelMessageRoleType.AI
else:
continue
@@ -825,14 +824,80 @@ class ConversableAgent(Role, Agent):
AgentMessage(
content=content,
role=role,
context=json.loads(item.context)
if item.context is not None
else None,
context=(
json.loads(item.context) if item.context is not None else None
),
)
)
return oai_messages
def _load_thinking_messages(
async def _load_thinking_messages(
self,
received_message: AgentMessage,
sender: Agent,
rely_messages: Optional[List[AgentMessage]] = None,
context: Optional[Dict[str, Any]] = None,
) -> List[AgentMessage]:
observation = received_message.content
if not observation:
raise ValueError("The received message content is empty!")
memories = await self.read_memories(observation)
reply_message_str = ""
if context is None:
context = {}
if rely_messages:
copied_rely_messages = [m.copy() for m in rely_messages]
# When directly relying on historical messages, use the execution result
# content as a dependency
for message in copied_rely_messages:
action_report: Optional[ActionOutput] = ActionOutput.from_dict(
message.action_report
)
if action_report:
# TODO: Modify in-place, need to be optimized
message.content = action_report.content
if message.name != self.role:
# TODO, use name
# Rely messages are not from the current agent
if message.role == ModelMessageRoleType.HUMAN:
reply_message_str += f"Question: {message.content}\n"
elif message.role == ModelMessageRoleType.AI:
reply_message_str += f"Observation: {message.content}\n"
if reply_message_str:
memories += "\n" + reply_message_str
system_prompt = await self.build_prompt(
question=observation,
is_system=True,
most_recent_memories=memories,
**context,
)
user_prompt = await self.build_prompt(
question=observation,
is_system=False,
most_recent_memories=memories,
**context,
)
agent_messages = []
if system_prompt:
agent_messages.append(
AgentMessage(
content=system_prompt,
role=ModelMessageRoleType.SYSTEM,
)
)
if user_prompt:
agent_messages.append(
AgentMessage(
content=user_prompt,
role=ModelMessageRoleType.HUMAN,
)
)
return agent_messages
def _old_load_thinking_messages(
self,
received_message: AgentMessage,
sender: Agent,
@@ -846,8 +911,8 @@ class ConversableAgent(Role, Agent):
with root_tracer.start_span(
"agent._load_thinking_messages",
metadata={
"sender": sender.get_name(),
"recipient": self.get_name(),
"sender": sender.name,
"recipient": self.name,
"conv_uid": self.not_null_agent_context.conv_id,
"current_goal": current_goal,
},
@@ -855,8 +920,8 @@ class ConversableAgent(Role, Agent):
# Get historical information from the memory
memory_messages = self.memory.message_memory.get_between_agents(
self.not_null_agent_context.conv_id,
self.profile,
sender.get_profile(),
self.role,
sender.role,
current_goal,
)
span.metadata["memory_messages"] = [