feat(agents): add ReActAgent (#2420)

Co-authored-by: dongzhancai1 <dongzhancai1@jd.com>
Cooper 2025-03-09 20:23:31 +08:00 committed by GitHub
parent a3216a7994
commit 81f4c6a558
3 changed files with 472 additions and 0 deletions

@@ -0,0 +1,89 @@
import asyncio
import logging
import os
import sys
from typing_extensions import Annotated, Doc
from dbgpt.agent import AgentContext, AgentMemory, LLMConfig, UserProxyAgent
from dbgpt.agent.expand.react_agent import ReActAgent
from dbgpt.agent.resource import ToolPack, tool
logging.basicConfig(
stream=sys.stdout,
level=logging.INFO,
format="%(asctime)s | %(levelname)s | %(name)s | %(message)s",
)
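# The "terminate" tool serves as the ReAct end action: the agent is expected to
# call it with the final answer once the goal is achieved.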
@tool
def terminate(
final_answer: Annotated[str, Doc("final literal answer about the goal")],
) -> str:
"""When the goal achieved, this tool must be called."""
return final_answer
@tool
def simple_calculator(first_number: int, second_number: int, operator: str) -> float:
"""Simple calculator tool. Just support +, -, *, /."""
if isinstance(first_number, str):
first_number = int(first_number)
if isinstance(second_number, str):
second_number = int(second_number)
if operator == "+":
return first_number + second_number
elif operator == "-":
return first_number - second_number
elif operator == "*":
return first_number * second_number
elif operator == "/":
return first_number / second_number
else:
raise ValueError(f"Invalid operator: {operator}")
@tool
def count_directory_files(path: Annotated[str, Doc("The directory path")]) -> int:
"""Count the number of files in a directory."""
if not os.path.isdir(path):
raise ValueError(f"Invalid directory path: {path}")
return len(os.listdir(path))
async def main():
from dbgpt.model.proxy.llms.siliconflow import SiliconFlowLLMClient
llm_client = SiliconFlowLLMClient(
model_alias="Qwen/Qwen2-7B-Instruct",
)
agent_memory = AgentMemory()
agent_memory.gpts_memory.init(conv_id="test456")
context: AgentContext = AgentContext(conv_id="test456", gpts_app_name="ReAct")
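    # The tools below are bundled into a ToolPack and bound to the agent as its
    # action space.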
tools = ToolPack([simple_calculator, count_directory_files, terminate])
user_proxy = await UserProxyAgent().bind(agent_memory).bind(context).build()
tool_engineer = (
await ReActAgent(end_action_name="terminate", max_steps=10)
.bind(context)
.bind(LLMConfig(llm_client=llm_client))
.bind(agent_memory)
.bind(tools)
.build()
)
await user_proxy.initiate_chat(
recipient=tool_engineer,
reviewer=user_proxy,
message="Calculate the product of 10 and 99, Count the number of files in /tmp answer in Chinese.",
)
# dbgpt-vis message infos
print(await agent_memory.gpts_memory.app_link_chat_message("test456"))
if __name__ == "__main__":
asyncio.run(main())

@@ -0,0 +1,54 @@
import json
import logging
from typing import Optional
from dbgpt.agent import ResourceType
from dbgpt.agent.expand.actions.tool_action import ToolAction, ToolInput
from dbgpt.vis import Vis, VisPlugin
logger = logging.getLogger(__name__)
class ReActAction(ToolAction):
"""ReAct action class."""
def __init__(self, **kwargs):
"""Tool action init."""
super().__init__(**kwargs)
self._render_protocol = VisPlugin()
@property
def resource_need(self) -> Optional[ResourceType]:
"""Return the resource type needed for the action."""
return ResourceType.Tool
@property
def render_protocol(self) -> Optional[Vis]:
"""Return the render protocol."""
return self._render_protocol
@property
def out_model_type(self):
"""Return the output model type."""
return ToolInput
@property
def ai_out_schema(self) -> Optional[str]:
"""Return the AI output schema."""
out_put_schema = {
"Thought": "Summary of thoughts to the user",
"Action": {
"tool_name": "The name of a tool that can be used to answer "
"the current"
"question or solve the current task.",
"args": {
"arg name1": "arg value1",
"arg name2": "arg value2",
},
},
}
return f"""Please response in the following json format:
{json.dumps(out_put_schema, indent=2, ensure_ascii=False)}
Make sure the response is correct json and can be parsed by Python json.loads.
"""

@@ -0,0 +1,329 @@
import json
import logging
from typing import Any, List, Optional, Tuple
from dbgpt.agent import (
ActionOutput,
Agent,
AgentMemoryFragment,
AgentMessage,
ConversableAgent,
ProfileConfig,
ResourceType,
)
from dbgpt.agent.expand.actions.react_action import ReActAction
from dbgpt.core import ModelMessageRoleType
from dbgpt.util.configure import DynConfig
from dbgpt.util.json_utils import find_json_objects
logger = logging.getLogger(__name__)
_REACT_SYSTEM_TEMPLATE = """\
You are a {{ role }}, {% if name %}named {{ name }},
{% endif %}your goal is {% if is_retry_chat %}{{ retry_goal }}
{% else %}{{ goal }}
{% endif %}.\
At the same time, please strictly abide by the constraints and specifications
in the "IMPORTANT REMINDER" below.
{% if resource_prompt %}\
# ACTION SPACE #
{{ resource_prompt }}
{% endif %}
{% if expand_prompt %}\
{{ expand_prompt }}
{% endif %}\
# IMPORTANT REMINDER #
The current time is: {{ now_time }}.
{% if constraints %}\
{% for constraint in constraints %}\
{{ loop.index }}. {{ constraint }}
{% endfor %}\
{% endif %}\
{% if is_retry_chat %}\
{% if retry_constraints %}\
{% for retry_constraint in retry_constraints %}\
{{ loop.index }}. {{ retry_constraint }}
{% endfor %}\
{% endif %}\
{% else %}\
{% endif %}\
{% if examples %}\
# EXAMPLE INTERACTION #
You can refer to the following examples:
{{ examples }}\
{% endif %}\
{% if most_recent_memories %}\
# History of Solving Task #
{{ most_recent_memories }}\
{% endif %}\
# RESPONSE FORMAT #
{% if out_schema %} {{ out_schema }} {% endif %}\
################### TASK ###################
Please solve the task:
"""
_REACT_WRITE_MEMORY_TEMPLATE = """\
{% if question %}Question: {{ question }} {% endif %}
{% if assistant %}Assistant: {{ assistant }} {% endif %}
{% if observation %}Observation: {{ observation }} {% endif %}
"""
class ReActAgent(ConversableAgent):
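    """ReAct agent that solves tasks via an iterative Thought/Action/Observation loop."""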
end_action_name: str = DynConfig(
"terminate",
category="agent",
key="dbgpt_agent_expand_plugin_assistant_agent_end_action_name",
)
max_steps: int = DynConfig(
10,
category="agent",
key="dbgpt_agent_expand_plugin_assistant_agent_max_steps",
)
profile: ProfileConfig = ProfileConfig(
name=DynConfig(
"ReAct",
category="agent",
key="dbgpt_agent_expand_plugin_assistant_agent_name",
),
role=DynConfig(
"ToolMaster",
category="agent",
key="dbgpt_agent_expand_plugin_assistant_agent_role",
),
goal=DynConfig(
"Read and understand the tool information given in the action space "
"below to understand their capabilities and how to use them,and choosing "
"the right tool to solve the task",
category="agent",
key="dbgpt_agent_expand_plugin_assistant_agent_goal",
),
constraints=DynConfig(
[
"Achieve the goal step by step."
"Each step, please read the parameter definition of the tool carefully "
"and extract the specific parameters required to execute the tool "
"from the user's goal.",
"information in json format according to the following required format."
"If there is an example, please refer to the sample format output.",
"each step, you can only select one tool in action space.",
],
category="agent",
key="dbgpt_agent_expand_plugin_assistant_agent_constraints",
),
system_prompt_template=_REACT_SYSTEM_TEMPLATE,
write_memory_template=_REACT_WRITE_MEMORY_TEMPLATE,
)
def __init__(self, **kwargs):
super().__init__(**kwargs)
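        # ReActAction defines the Thought/Action output schema and executes the
        # selected tool.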
self._init_actions([ReActAction])
async def review(self, message: Optional[str], censored: Agent) -> Tuple[bool, Any]:
"""Review the message based on the censored message."""
try:
json_obj = find_json_objects(message)
if len(json_obj) == 0:
raise ValueError(
"No correct json object found in the message。"
"Please strictly output JSON in the defined "
"format, and only one action can be ouput each time. "
)
return True, json_obj[0]
except Exception as e:
logger.error(f"review error: {e}")
raise e
def validate_action(self, action_name: str) -> bool:
tools = self.resource.get_resource_by_type(ResourceType.Tool)
for tool in tools:
if tool.name == action_name:
return True
raise ValueError(f"{action_name} is not in the action space.")
async def generate_reply(
self,
received_message: AgentMessage,
sender: Agent,
reviewer: Optional[Agent] = None,
rely_messages: Optional[List[AgentMessage]] = None,
historical_dialogues: Optional[List[AgentMessage]] = None,
is_retry_chat: bool = False,
last_speaker_name: Optional[str] = None,
**kwargs,
) -> AgentMessage:
"""Generate a reply based on the received messages."""
try:
logger.info(
f"generate agent reply!sender={sender}, "
f"rely_messages_len={rely_messages}"
)
self.validate_action(self.end_action_name)
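            # Seed the loop with a placeholder observation; the actual task text
            # is appended to the system prompt below.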
observation = AgentMessage(content="please start!")
reply_message: AgentMessage = self._init_reply_message(
received_message=received_message
)
thinking_messages, resource_info = await self._load_thinking_messages(
received_message=observation,
sender=sender,
rely_messages=rely_messages,
historical_dialogues=historical_dialogues,
context=reply_message.get_dict_context(),
is_retry_chat=is_retry_chat,
)
# attach current task to system prompt
thinking_messages[0].content = (
thinking_messages[0].content + "\n" + received_message.content
)
done = False
max_steps = self.max_steps
await self.write_memories(
question=received_message.content,
ai_message="",
)
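            # ReAct loop: think -> act -> observe, repeated until the end action
            # ("terminate") passes verification or the step budget is exhausted.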
while not done and max_steps > 0:
ai_message = ""
try:
# 1. thinking
llm_reply, model_name = await self.thinking(
thinking_messages, sender
)
reply_message.model_name = model_name
reply_message.resource_info = resource_info
ai_message = llm_reply
thinking_messages.append(
AgentMessage(role=ModelMessageRoleType.AI, content=llm_reply)
)
approve, json_obj = await self.review(llm_reply, self)
logger.info(f"jons_obj: {json_obj}")
action = json_obj["Action"]
thought = json_obj["Thought"]
action.update({"thought": thought})
reply_message.content = json.dumps(action, ensure_ascii=False)
tool_name = action["tool_name"]
self.validate_action(tool_name)
# 2. act
act_extent_param = self.prepare_act_param(
received_message=received_message,
sender=sender,
rely_messages=rely_messages,
historical_dialogues=historical_dialogues,
)
act_out: ActionOutput = await self.act(
message=reply_message,
sender=sender,
reviewer=reviewer,
is_retry_chat=is_retry_chat,
last_speaker_name=last_speaker_name,
**act_extent_param,
)
if act_out:
reply_message.action_report = act_out
# 3. obs
check_pass, reason = await self.verify(
reply_message, sender, reviewer
)
done = tool_name == self.end_action_name and check_pass
if check_pass:
logger.info(f"Observation:{act_out.content}")
thinking_messages.append(
AgentMessage(
role=ModelMessageRoleType.HUMAN,
content=f"Observation: {tool_name} "
f"output:{act_out.content}\n",
)
)
await self.write_memories(
question="",
ai_message=ai_message,
action_output=act_out,
check_pass=check_pass,
)
else:
observation = f"Observation: {reason}"
logger.info(f"Observation:{observation}")
thinking_messages.append(
AgentMessage(
role=ModelMessageRoleType.HUMAN, content=observation
)
)
await self.write_memories(
question="",
ai_message=ai_message,
check_pass=check_pass,
check_fail_reason=reason,
)
max_steps -= 1
except Exception as e:
                    fail_reason = (
                        f"Observation: exception occurred ({type(e).__name__}): {e}."
                    )
logger.error(fail_reason)
thinking_messages.append(
AgentMessage(
role=ModelMessageRoleType.HUMAN, content=fail_reason
)
)
await self.write_memories(
question="",
ai_message=ai_message,
check_pass=False,
check_fail_reason=fail_reason,
)
reply_message.success = done
await self.adjust_final_message(True, reply_message)
return reply_message
except Exception as e:
logger.exception("Generate reply exception!")
err_message = AgentMessage(content=str(e))
err_message.success = False
return err_message
async def write_memories(
self,
question: str,
ai_message: str,
action_output: Optional[ActionOutput] = None,
check_pass: bool = True,
check_fail_reason: Optional[str] = None,
) -> None:
"""Write the memories to the memory.
We suggest you to override this method to save the conversation to memory
according to your needs.
Args:
question(str): The question received.
ai_message(str): The AI message, LLM output.
action_output(ActionOutput): The action output.
check_pass(bool): Whether the check pass.
check_fail_reason(str): The check fail reason.
"""
observation = ""
if action_output and action_output.observations:
observation = action_output.observations
elif check_fail_reason:
observation = check_fail_reason
memory_map = {
"question": question,
"assistant": ai_message,
"observation": observation,
}
write_memory_template = self.write_memory_template
memory_content = self._render_template(write_memory_template, **memory_map)
fragment = AgentMemoryFragment(memory_content)
await self.memory.write(fragment)