mirror of
https://github.com/csunny/DB-GPT.git
synced 2025-08-10 12:42:34 +00:00
feat(agents): add ReActAgent (#2420)
Co-authored-by: dongzhancai1 <dongzhancai1@jd.com>
This commit is contained in:
parent
a3216a7994
commit
81f4c6a558
89
examples/agents/react_agent_example.py
Normal file
89
examples/agents/react_agent_example.py
Normal file
@ -0,0 +1,89 @@
|
|||||||
|
import asyncio
|
||||||
|
import logging
|
||||||
|
import os
|
||||||
|
import sys
|
||||||
|
|
||||||
|
from typing_extensions import Annotated, Doc
|
||||||
|
|
||||||
|
from dbgpt.agent import AgentContext, AgentMemory, LLMConfig, UserProxyAgent
|
||||||
|
from dbgpt.agent.expand.react_agent import ReActAgent
|
||||||
|
from dbgpt.agent.resource import ToolPack, tool
|
||||||
|
|
||||||
|
logging.basicConfig(
|
||||||
|
stream=sys.stdout,
|
||||||
|
level=logging.INFO,
|
||||||
|
format="%(asctime)s | %(levelname)s | %(name)s | %(message)s",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@tool
|
||||||
|
def terminate(
|
||||||
|
final_answer: Annotated[str, Doc("final literal answer about the goal")],
|
||||||
|
) -> str:
|
||||||
|
"""When the goal achieved, this tool must be called."""
|
||||||
|
return final_answer
|
||||||
|
|
||||||
|
|
||||||
|
@tool
|
||||||
|
def simple_calculator(first_number: int, second_number: int, operator: str) -> float:
|
||||||
|
"""Simple calculator tool. Just support +, -, *, /."""
|
||||||
|
if isinstance(first_number, str):
|
||||||
|
first_number = int(first_number)
|
||||||
|
if isinstance(second_number, str):
|
||||||
|
second_number = int(second_number)
|
||||||
|
if operator == "+":
|
||||||
|
return first_number + second_number
|
||||||
|
elif operator == "-":
|
||||||
|
return first_number - second_number
|
||||||
|
elif operator == "*":
|
||||||
|
return first_number * second_number
|
||||||
|
elif operator == "/":
|
||||||
|
return first_number / second_number
|
||||||
|
else:
|
||||||
|
raise ValueError(f"Invalid operator: {operator}")
|
||||||
|
|
||||||
|
|
||||||
|
@tool
|
||||||
|
def count_directory_files(path: Annotated[str, Doc("The directory path")]) -> int:
|
||||||
|
"""Count the number of files in a directory."""
|
||||||
|
if not os.path.isdir(path):
|
||||||
|
raise ValueError(f"Invalid directory path: {path}")
|
||||||
|
return len(os.listdir(path))
|
||||||
|
|
||||||
|
|
||||||
|
async def main():
|
||||||
|
from dbgpt.model.proxy.llms.siliconflow import SiliconFlowLLMClient
|
||||||
|
|
||||||
|
llm_client = SiliconFlowLLMClient(
|
||||||
|
model_alias="Qwen/Qwen2-7B-Instruct",
|
||||||
|
)
|
||||||
|
agent_memory = AgentMemory()
|
||||||
|
agent_memory.gpts_memory.init(conv_id="test456")
|
||||||
|
|
||||||
|
context: AgentContext = AgentContext(conv_id="test456", gpts_app_name="ReAct")
|
||||||
|
|
||||||
|
tools = ToolPack([simple_calculator, count_directory_files, terminate])
|
||||||
|
|
||||||
|
user_proxy = await UserProxyAgent().bind(agent_memory).bind(context).build()
|
||||||
|
|
||||||
|
tool_engineer = (
|
||||||
|
await ReActAgent(end_action_name="terminate", max_steps=10)
|
||||||
|
.bind(context)
|
||||||
|
.bind(LLMConfig(llm_client=llm_client))
|
||||||
|
.bind(agent_memory)
|
||||||
|
.bind(tools)
|
||||||
|
.build()
|
||||||
|
)
|
||||||
|
|
||||||
|
await user_proxy.initiate_chat(
|
||||||
|
recipient=tool_engineer,
|
||||||
|
reviewer=user_proxy,
|
||||||
|
message="Calculate the product of 10 and 99, Count the number of files in /tmp, answer in Chinese.",
|
||||||
|
)
|
||||||
|
|
||||||
|
# dbgpt-vis message infos
|
||||||
|
print(await agent_memory.gpts_memory.app_link_chat_message("test456"))
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
asyncio.run(main())
|
@ -0,0 +1,54 @@
|
|||||||
|
import json
|
||||||
|
import logging
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
from dbgpt.agent import ResourceType
|
||||||
|
from dbgpt.agent.expand.actions.tool_action import ToolAction, ToolInput
|
||||||
|
from dbgpt.vis import Vis, VisPlugin
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class ReActAction(ToolAction):
|
||||||
|
"""ReAct action class."""
|
||||||
|
|
||||||
|
def __init__(self, **kwargs):
|
||||||
|
"""Tool action init."""
|
||||||
|
super().__init__(**kwargs)
|
||||||
|
self._render_protocol = VisPlugin()
|
||||||
|
|
||||||
|
@property
|
||||||
|
def resource_need(self) -> Optional[ResourceType]:
|
||||||
|
"""Return the resource type needed for the action."""
|
||||||
|
return ResourceType.Tool
|
||||||
|
|
||||||
|
@property
|
||||||
|
def render_protocol(self) -> Optional[Vis]:
|
||||||
|
"""Return the render protocol."""
|
||||||
|
return self._render_protocol
|
||||||
|
|
||||||
|
@property
|
||||||
|
def out_model_type(self):
|
||||||
|
"""Return the output model type."""
|
||||||
|
return ToolInput
|
||||||
|
|
||||||
|
@property
|
||||||
|
def ai_out_schema(self) -> Optional[str]:
|
||||||
|
"""Return the AI output schema."""
|
||||||
|
out_put_schema = {
|
||||||
|
"Thought": "Summary of thoughts to the user",
|
||||||
|
"Action": {
|
||||||
|
"tool_name": "The name of a tool that can be used to answer "
|
||||||
|
"the current"
|
||||||
|
"question or solve the current task.",
|
||||||
|
"args": {
|
||||||
|
"arg name1": "arg value1",
|
||||||
|
"arg name2": "arg value2",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
return f"""Please response in the following json format:
|
||||||
|
{json.dumps(out_put_schema, indent=2, ensure_ascii=False)}
|
||||||
|
Make sure the response is correct json and can be parsed by Python json.loads.
|
||||||
|
"""
|
329
packages/dbgpt-core/src/dbgpt/agent/expand/react_agent.py
Normal file
329
packages/dbgpt-core/src/dbgpt/agent/expand/react_agent.py
Normal file
@ -0,0 +1,329 @@
|
|||||||
|
import json
|
||||||
|
import logging
|
||||||
|
from typing import Any, List, Optional, Tuple
|
||||||
|
|
||||||
|
from dbgpt.agent import (
|
||||||
|
ActionOutput,
|
||||||
|
Agent,
|
||||||
|
AgentMemoryFragment,
|
||||||
|
AgentMessage,
|
||||||
|
ConversableAgent,
|
||||||
|
ProfileConfig,
|
||||||
|
ResourceType,
|
||||||
|
)
|
||||||
|
from dbgpt.agent.expand.actions.react_action import ReActAction
|
||||||
|
from dbgpt.core import ModelMessageRoleType
|
||||||
|
from dbgpt.util.configure import DynConfig
|
||||||
|
from dbgpt.util.json_utils import find_json_objects
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
_REACT_SYSTEM_TEMPLATE = """\
|
||||||
|
You are a {{ role }}, {% if name %}named {{ name }}.
|
||||||
|
{% endif %}your goal is {% if is_retry_chat %}{{ retry_goal }}
|
||||||
|
{% else %}{{ goal }}
|
||||||
|
{% endif %}.\
|
||||||
|
At the same time, please strictly abide by the constraints and specifications
|
||||||
|
in the "IMPORTANT REMINDER" below.
|
||||||
|
{% if resource_prompt %}\
|
||||||
|
# ACTION SPACE #
|
||||||
|
{{ resource_prompt }}
|
||||||
|
{% endif %}
|
||||||
|
{% if expand_prompt %}\
|
||||||
|
{{ expand_prompt }}
|
||||||
|
{% endif %}\
|
||||||
|
|
||||||
|
|
||||||
|
# IMPORTANT REMINDER #
|
||||||
|
The current time is:{{now_time}}.
|
||||||
|
{% if constraints %}\
|
||||||
|
{% for constraint in constraints %}\
|
||||||
|
{{ loop.index }}. {{ constraint }}
|
||||||
|
{% endfor %}\
|
||||||
|
{% endif %}\
|
||||||
|
|
||||||
|
|
||||||
|
{% if is_retry_chat %}\
|
||||||
|
{% if retry_constraints %}\
|
||||||
|
{% for retry_constraint in retry_constraints %}\
|
||||||
|
{{ loop.index }}. {{ retry_constraint }}
|
||||||
|
{% endfor %}\
|
||||||
|
{% endif %}\
|
||||||
|
{% else %}\
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
{% endif %}\
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
{% if examples %}\
|
||||||
|
# EXAMPLE INTERACTION #
|
||||||
|
You can refer to the following examples:
|
||||||
|
{{ examples }}\
|
||||||
|
{% endif %}\
|
||||||
|
|
||||||
|
{% if most_recent_memories %}\
|
||||||
|
# History of Solving Task#
|
||||||
|
{{ most_recent_memories }}\
|
||||||
|
{% endif %}\
|
||||||
|
|
||||||
|
# RESPONSE FORMAT #
|
||||||
|
{% if out_schema %} {{ out_schema }} {% endif %}\
|
||||||
|
|
||||||
|
################### TASK ###################
|
||||||
|
Please solve the task:
|
||||||
|
"""
|
||||||
|
|
||||||
|
|
||||||
|
_REACT_WRITE_MEMORY_TEMPLATE = """\
|
||||||
|
{% if question %}Question: {{ question }} {% endif %}
|
||||||
|
{% if assistant %}Assistant: {{ assistant }} {% endif %}
|
||||||
|
{% if observation %}Observation: {{ observation }} {% endif %}
|
||||||
|
"""
|
||||||
|
|
||||||
|
|
||||||
|
class ReActAgent(ConversableAgent):
|
||||||
|
end_action_name: str = DynConfig(
|
||||||
|
"terminate",
|
||||||
|
category="agent",
|
||||||
|
key="dbgpt_agent_expand_plugin_assistant_agent_end_action_name",
|
||||||
|
)
|
||||||
|
max_steps: int = DynConfig(
|
||||||
|
10,
|
||||||
|
category="agent",
|
||||||
|
key="dbgpt_agent_expand_plugin_assistant_agent_max_steps",
|
||||||
|
)
|
||||||
|
profile: ProfileConfig = ProfileConfig(
|
||||||
|
name=DynConfig(
|
||||||
|
"ReAct",
|
||||||
|
category="agent",
|
||||||
|
key="dbgpt_agent_expand_plugin_assistant_agent_name",
|
||||||
|
),
|
||||||
|
role=DynConfig(
|
||||||
|
"ToolMaster",
|
||||||
|
category="agent",
|
||||||
|
key="dbgpt_agent_expand_plugin_assistant_agent_role",
|
||||||
|
),
|
||||||
|
goal=DynConfig(
|
||||||
|
"Read and understand the tool information given in the action space "
|
||||||
|
"below to understand their capabilities and how to use them,and choosing "
|
||||||
|
"the right tool to solve the task",
|
||||||
|
category="agent",
|
||||||
|
key="dbgpt_agent_expand_plugin_assistant_agent_goal",
|
||||||
|
),
|
||||||
|
constraints=DynConfig(
|
||||||
|
[
|
||||||
|
"Achieve the goal step by step."
|
||||||
|
"Each step, please read the parameter definition of the tool carefully "
|
||||||
|
"and extract the specific parameters required to execute the tool "
|
||||||
|
"from the user's goal.",
|
||||||
|
"information in json format according to the following required format."
|
||||||
|
"If there is an example, please refer to the sample format output.",
|
||||||
|
"each step, you can only select one tool in action space.",
|
||||||
|
],
|
||||||
|
category="agent",
|
||||||
|
key="dbgpt_agent_expand_plugin_assistant_agent_constraints",
|
||||||
|
),
|
||||||
|
system_prompt_template=_REACT_SYSTEM_TEMPLATE,
|
||||||
|
write_memory_template=_REACT_WRITE_MEMORY_TEMPLATE,
|
||||||
|
)
|
||||||
|
|
||||||
|
def __init__(self, **kwargs):
|
||||||
|
super().__init__(**kwargs)
|
||||||
|
self._init_actions([ReActAction])
|
||||||
|
|
||||||
|
async def review(self, message: Optional[str], censored: Agent) -> Tuple[bool, Any]:
|
||||||
|
"""Review the message based on the censored message."""
|
||||||
|
try:
|
||||||
|
json_obj = find_json_objects(message)
|
||||||
|
if len(json_obj) == 0:
|
||||||
|
raise ValueError(
|
||||||
|
"No correct json object found in the message。"
|
||||||
|
"Please strictly output JSON in the defined "
|
||||||
|
"format, and only one action can be ouput each time. "
|
||||||
|
)
|
||||||
|
return True, json_obj[0]
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"review error: {e}")
|
||||||
|
raise e
|
||||||
|
|
||||||
|
def validate_action(self, action_name: str) -> bool:
|
||||||
|
tools = self.resource.get_resource_by_type(ResourceType.Tool)
|
||||||
|
for tool in tools:
|
||||||
|
if tool.name == action_name:
|
||||||
|
return True
|
||||||
|
raise ValueError(f"{action_name} is not in the action space.")
|
||||||
|
|
||||||
|
async def generate_reply(
|
||||||
|
self,
|
||||||
|
received_message: AgentMessage,
|
||||||
|
sender: Agent,
|
||||||
|
reviewer: Optional[Agent] = None,
|
||||||
|
rely_messages: Optional[List[AgentMessage]] = None,
|
||||||
|
historical_dialogues: Optional[List[AgentMessage]] = None,
|
||||||
|
is_retry_chat: bool = False,
|
||||||
|
last_speaker_name: Optional[str] = None,
|
||||||
|
**kwargs,
|
||||||
|
) -> AgentMessage:
|
||||||
|
"""Generate a reply based on the received messages."""
|
||||||
|
try:
|
||||||
|
logger.info(
|
||||||
|
f"generate agent reply!sender={sender}, "
|
||||||
|
f"rely_messages_len={rely_messages}"
|
||||||
|
)
|
||||||
|
self.validate_action(self.end_action_name)
|
||||||
|
observation = AgentMessage(content="please start!")
|
||||||
|
reply_message: AgentMessage = self._init_reply_message(
|
||||||
|
received_message=received_message
|
||||||
|
)
|
||||||
|
thinking_messages, resource_info = await self._load_thinking_messages(
|
||||||
|
received_message=observation,
|
||||||
|
sender=sender,
|
||||||
|
rely_messages=rely_messages,
|
||||||
|
historical_dialogues=historical_dialogues,
|
||||||
|
context=reply_message.get_dict_context(),
|
||||||
|
is_retry_chat=is_retry_chat,
|
||||||
|
)
|
||||||
|
# attach current task to system prompt
|
||||||
|
thinking_messages[0].content = (
|
||||||
|
thinking_messages[0].content + "\n" + received_message.content
|
||||||
|
)
|
||||||
|
done = False
|
||||||
|
max_steps = self.max_steps
|
||||||
|
await self.write_memories(
|
||||||
|
question=received_message.content,
|
||||||
|
ai_message="",
|
||||||
|
)
|
||||||
|
while not done and max_steps > 0:
|
||||||
|
ai_message = ""
|
||||||
|
try:
|
||||||
|
# 1. thinking
|
||||||
|
llm_reply, model_name = await self.thinking(
|
||||||
|
thinking_messages, sender
|
||||||
|
)
|
||||||
|
reply_message.model_name = model_name
|
||||||
|
reply_message.resource_info = resource_info
|
||||||
|
ai_message = llm_reply
|
||||||
|
thinking_messages.append(
|
||||||
|
AgentMessage(role=ModelMessageRoleType.AI, content=llm_reply)
|
||||||
|
)
|
||||||
|
approve, json_obj = await self.review(llm_reply, self)
|
||||||
|
logger.info(f"jons_obj: {json_obj}")
|
||||||
|
action = json_obj["Action"]
|
||||||
|
thought = json_obj["Thought"]
|
||||||
|
action.update({"thought": thought})
|
||||||
|
reply_message.content = json.dumps(action, ensure_ascii=False)
|
||||||
|
tool_name = action["tool_name"]
|
||||||
|
self.validate_action(tool_name)
|
||||||
|
# 2. act
|
||||||
|
act_extent_param = self.prepare_act_param(
|
||||||
|
received_message=received_message,
|
||||||
|
sender=sender,
|
||||||
|
rely_messages=rely_messages,
|
||||||
|
historical_dialogues=historical_dialogues,
|
||||||
|
)
|
||||||
|
act_out: ActionOutput = await self.act(
|
||||||
|
message=reply_message,
|
||||||
|
sender=sender,
|
||||||
|
reviewer=reviewer,
|
||||||
|
is_retry_chat=is_retry_chat,
|
||||||
|
last_speaker_name=last_speaker_name,
|
||||||
|
**act_extent_param,
|
||||||
|
)
|
||||||
|
if act_out:
|
||||||
|
reply_message.action_report = act_out
|
||||||
|
|
||||||
|
# 3. obs
|
||||||
|
check_pass, reason = await self.verify(
|
||||||
|
reply_message, sender, reviewer
|
||||||
|
)
|
||||||
|
done = tool_name == self.end_action_name and check_pass
|
||||||
|
if check_pass:
|
||||||
|
logger.info(f"Observation:{act_out.content}")
|
||||||
|
thinking_messages.append(
|
||||||
|
AgentMessage(
|
||||||
|
role=ModelMessageRoleType.HUMAN,
|
||||||
|
content=f"Observation: {tool_name} "
|
||||||
|
f"output:{act_out.content}\n",
|
||||||
|
)
|
||||||
|
)
|
||||||
|
await self.write_memories(
|
||||||
|
question="",
|
||||||
|
ai_message=ai_message,
|
||||||
|
action_output=act_out,
|
||||||
|
check_pass=check_pass,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
observation = f"Observation: {reason}"
|
||||||
|
logger.info(f"Observation:{observation}")
|
||||||
|
thinking_messages.append(
|
||||||
|
AgentMessage(
|
||||||
|
role=ModelMessageRoleType.HUMAN, content=observation
|
||||||
|
)
|
||||||
|
)
|
||||||
|
await self.write_memories(
|
||||||
|
question="",
|
||||||
|
ai_message=ai_message,
|
||||||
|
check_pass=check_pass,
|
||||||
|
check_fail_reason=reason,
|
||||||
|
)
|
||||||
|
max_steps -= 1
|
||||||
|
except Exception as e:
|
||||||
|
fail_reason = (
|
||||||
|
f"Observation: Exception occurs:({type(e).__name__}){e}."
|
||||||
|
)
|
||||||
|
logger.error(fail_reason)
|
||||||
|
thinking_messages.append(
|
||||||
|
AgentMessage(
|
||||||
|
role=ModelMessageRoleType.HUMAN, content=fail_reason
|
||||||
|
)
|
||||||
|
)
|
||||||
|
await self.write_memories(
|
||||||
|
question="",
|
||||||
|
ai_message=ai_message,
|
||||||
|
check_pass=False,
|
||||||
|
check_fail_reason=fail_reason,
|
||||||
|
)
|
||||||
|
reply_message.success = done
|
||||||
|
await self.adjust_final_message(True, reply_message)
|
||||||
|
return reply_message
|
||||||
|
except Exception as e:
|
||||||
|
logger.exception("Generate reply exception!")
|
||||||
|
err_message = AgentMessage(content=str(e))
|
||||||
|
err_message.success = False
|
||||||
|
return err_message
|
||||||
|
|
||||||
|
async def write_memories(
|
||||||
|
self,
|
||||||
|
question: str,
|
||||||
|
ai_message: str,
|
||||||
|
action_output: Optional[ActionOutput] = None,
|
||||||
|
check_pass: bool = True,
|
||||||
|
check_fail_reason: Optional[str] = None,
|
||||||
|
) -> None:
|
||||||
|
"""Write the memories to the memory.
|
||||||
|
|
||||||
|
We suggest you to override this method to save the conversation to memory
|
||||||
|
according to your needs.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
question(str): The question received.
|
||||||
|
ai_message(str): The AI message, LLM output.
|
||||||
|
action_output(ActionOutput): The action output.
|
||||||
|
check_pass(bool): Whether the check pass.
|
||||||
|
check_fail_reason(str): The check fail reason.
|
||||||
|
"""
|
||||||
|
observation = ""
|
||||||
|
if action_output and action_output.observations:
|
||||||
|
observation = action_output.observations
|
||||||
|
elif check_fail_reason:
|
||||||
|
observation = check_fail_reason
|
||||||
|
memory_map = {
|
||||||
|
"question": question,
|
||||||
|
"assistant": ai_message,
|
||||||
|
"observation": observation,
|
||||||
|
}
|
||||||
|
write_memory_template = self.write_memory_template
|
||||||
|
memory_content = self._render_template(write_memory_template, **memory_map)
|
||||||
|
fragment = AgentMemoryFragment(memory_content)
|
||||||
|
await self.memory.write(fragment)
|
Loading…
Reference in New Issue
Block a user