mirror of
https://github.com/csunny/DB-GPT.git
synced 2025-08-01 00:03:29 +00:00
feat(agents): add ReActAgent (#2420)
Co-authored-by: dongzhancai1 <dongzhancai1@jd.com>
This commit is contained in:
parent
a3216a7994
commit
81f4c6a558
89
examples/agents/react_agent_example.py
Normal file
89
examples/agents/react_agent_example.py
Normal file
@ -0,0 +1,89 @@
|
||||
import asyncio
|
||||
import logging
|
||||
import os
|
||||
import sys
|
||||
|
||||
from typing_extensions import Annotated, Doc
|
||||
|
||||
from dbgpt.agent import AgentContext, AgentMemory, LLMConfig, UserProxyAgent
|
||||
from dbgpt.agent.expand.react_agent import ReActAgent
|
||||
from dbgpt.agent.resource import ToolPack, tool
|
||||
|
||||
logging.basicConfig(
|
||||
stream=sys.stdout,
|
||||
level=logging.INFO,
|
||||
format="%(asctime)s | %(levelname)s | %(name)s | %(message)s",
|
||||
)
|
||||
|
||||
|
||||
@tool
|
||||
def terminate(
|
||||
final_answer: Annotated[str, Doc("final literal answer about the goal")],
|
||||
) -> str:
|
||||
"""When the goal achieved, this tool must be called."""
|
||||
return final_answer
|
||||
|
||||
|
||||
@tool
|
||||
def simple_calculator(first_number: int, second_number: int, operator: str) -> float:
|
||||
"""Simple calculator tool. Just support +, -, *, /."""
|
||||
if isinstance(first_number, str):
|
||||
first_number = int(first_number)
|
||||
if isinstance(second_number, str):
|
||||
second_number = int(second_number)
|
||||
if operator == "+":
|
||||
return first_number + second_number
|
||||
elif operator == "-":
|
||||
return first_number - second_number
|
||||
elif operator == "*":
|
||||
return first_number * second_number
|
||||
elif operator == "/":
|
||||
return first_number / second_number
|
||||
else:
|
||||
raise ValueError(f"Invalid operator: {operator}")
|
||||
|
||||
|
||||
@tool
|
||||
def count_directory_files(path: Annotated[str, Doc("The directory path")]) -> int:
|
||||
"""Count the number of files in a directory."""
|
||||
if not os.path.isdir(path):
|
||||
raise ValueError(f"Invalid directory path: {path}")
|
||||
return len(os.listdir(path))
|
||||
|
||||
|
||||
async def main():
|
||||
from dbgpt.model.proxy.llms.siliconflow import SiliconFlowLLMClient
|
||||
|
||||
llm_client = SiliconFlowLLMClient(
|
||||
model_alias="Qwen/Qwen2-7B-Instruct",
|
||||
)
|
||||
agent_memory = AgentMemory()
|
||||
agent_memory.gpts_memory.init(conv_id="test456")
|
||||
|
||||
context: AgentContext = AgentContext(conv_id="test456", gpts_app_name="ReAct")
|
||||
|
||||
tools = ToolPack([simple_calculator, count_directory_files, terminate])
|
||||
|
||||
user_proxy = await UserProxyAgent().bind(agent_memory).bind(context).build()
|
||||
|
||||
tool_engineer = (
|
||||
await ReActAgent(end_action_name="terminate", max_steps=10)
|
||||
.bind(context)
|
||||
.bind(LLMConfig(llm_client=llm_client))
|
||||
.bind(agent_memory)
|
||||
.bind(tools)
|
||||
.build()
|
||||
)
|
||||
|
||||
await user_proxy.initiate_chat(
|
||||
recipient=tool_engineer,
|
||||
reviewer=user_proxy,
|
||||
message="Calculate the product of 10 and 99, Count the number of files in /tmp, answer in Chinese.",
|
||||
)
|
||||
|
||||
# dbgpt-vis message infos
|
||||
print(await agent_memory.gpts_memory.app_link_chat_message("test456"))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
asyncio.run(main())
|
@ -0,0 +1,54 @@
|
||||
import json
|
||||
import logging
|
||||
from typing import Optional
|
||||
|
||||
from dbgpt.agent import ResourceType
|
||||
from dbgpt.agent.expand.actions.tool_action import ToolAction, ToolInput
|
||||
from dbgpt.vis import Vis, VisPlugin
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class ReActAction(ToolAction):
|
||||
"""ReAct action class."""
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
"""Tool action init."""
|
||||
super().__init__(**kwargs)
|
||||
self._render_protocol = VisPlugin()
|
||||
|
||||
@property
|
||||
def resource_need(self) -> Optional[ResourceType]:
|
||||
"""Return the resource type needed for the action."""
|
||||
return ResourceType.Tool
|
||||
|
||||
@property
|
||||
def render_protocol(self) -> Optional[Vis]:
|
||||
"""Return the render protocol."""
|
||||
return self._render_protocol
|
||||
|
||||
@property
|
||||
def out_model_type(self):
|
||||
"""Return the output model type."""
|
||||
return ToolInput
|
||||
|
||||
@property
|
||||
def ai_out_schema(self) -> Optional[str]:
|
||||
"""Return the AI output schema."""
|
||||
out_put_schema = {
|
||||
"Thought": "Summary of thoughts to the user",
|
||||
"Action": {
|
||||
"tool_name": "The name of a tool that can be used to answer "
|
||||
"the current"
|
||||
"question or solve the current task.",
|
||||
"args": {
|
||||
"arg name1": "arg value1",
|
||||
"arg name2": "arg value2",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
return f"""Please response in the following json format:
|
||||
{json.dumps(out_put_schema, indent=2, ensure_ascii=False)}
|
||||
Make sure the response is correct json and can be parsed by Python json.loads.
|
||||
"""
|
329
packages/dbgpt-core/src/dbgpt/agent/expand/react_agent.py
Normal file
329
packages/dbgpt-core/src/dbgpt/agent/expand/react_agent.py
Normal file
@ -0,0 +1,329 @@
|
||||
import json
|
||||
import logging
|
||||
from typing import Any, List, Optional, Tuple
|
||||
|
||||
from dbgpt.agent import (
|
||||
ActionOutput,
|
||||
Agent,
|
||||
AgentMemoryFragment,
|
||||
AgentMessage,
|
||||
ConversableAgent,
|
||||
ProfileConfig,
|
||||
ResourceType,
|
||||
)
|
||||
from dbgpt.agent.expand.actions.react_action import ReActAction
|
||||
from dbgpt.core import ModelMessageRoleType
|
||||
from dbgpt.util.configure import DynConfig
|
||||
from dbgpt.util.json_utils import find_json_objects
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
_REACT_SYSTEM_TEMPLATE = """\
|
||||
You are a {{ role }}, {% if name %}named {{ name }}.
|
||||
{% endif %}your goal is {% if is_retry_chat %}{{ retry_goal }}
|
||||
{% else %}{{ goal }}
|
||||
{% endif %}.\
|
||||
At the same time, please strictly abide by the constraints and specifications
|
||||
in the "IMPORTANT REMINDER" below.
|
||||
{% if resource_prompt %}\
|
||||
# ACTION SPACE #
|
||||
{{ resource_prompt }}
|
||||
{% endif %}
|
||||
{% if expand_prompt %}\
|
||||
{{ expand_prompt }}
|
||||
{% endif %}\
|
||||
|
||||
|
||||
# IMPORTANT REMINDER #
|
||||
The current time is:{{now_time}}.
|
||||
{% if constraints %}\
|
||||
{% for constraint in constraints %}\
|
||||
{{ loop.index }}. {{ constraint }}
|
||||
{% endfor %}\
|
||||
{% endif %}\
|
||||
|
||||
|
||||
{% if is_retry_chat %}\
|
||||
{% if retry_constraints %}\
|
||||
{% for retry_constraint in retry_constraints %}\
|
||||
{{ loop.index }}. {{ retry_constraint }}
|
||||
{% endfor %}\
|
||||
{% endif %}\
|
||||
{% else %}\
|
||||
|
||||
|
||||
|
||||
{% endif %}\
|
||||
|
||||
|
||||
|
||||
{% if examples %}\
|
||||
# EXAMPLE INTERACTION #
|
||||
You can refer to the following examples:
|
||||
{{ examples }}\
|
||||
{% endif %}\
|
||||
|
||||
{% if most_recent_memories %}\
|
||||
# History of Solving Task#
|
||||
{{ most_recent_memories }}\
|
||||
{% endif %}\
|
||||
|
||||
# RESPONSE FORMAT #
|
||||
{% if out_schema %} {{ out_schema }} {% endif %}\
|
||||
|
||||
################### TASK ###################
|
||||
Please solve the task:
|
||||
"""
|
||||
|
||||
|
||||
_REACT_WRITE_MEMORY_TEMPLATE = """\
|
||||
{% if question %}Question: {{ question }} {% endif %}
|
||||
{% if assistant %}Assistant: {{ assistant }} {% endif %}
|
||||
{% if observation %}Observation: {{ observation }} {% endif %}
|
||||
"""
|
||||
|
||||
|
||||
class ReActAgent(ConversableAgent):
|
||||
end_action_name: str = DynConfig(
|
||||
"terminate",
|
||||
category="agent",
|
||||
key="dbgpt_agent_expand_plugin_assistant_agent_end_action_name",
|
||||
)
|
||||
max_steps: int = DynConfig(
|
||||
10,
|
||||
category="agent",
|
||||
key="dbgpt_agent_expand_plugin_assistant_agent_max_steps",
|
||||
)
|
||||
profile: ProfileConfig = ProfileConfig(
|
||||
name=DynConfig(
|
||||
"ReAct",
|
||||
category="agent",
|
||||
key="dbgpt_agent_expand_plugin_assistant_agent_name",
|
||||
),
|
||||
role=DynConfig(
|
||||
"ToolMaster",
|
||||
category="agent",
|
||||
key="dbgpt_agent_expand_plugin_assistant_agent_role",
|
||||
),
|
||||
goal=DynConfig(
|
||||
"Read and understand the tool information given in the action space "
|
||||
"below to understand their capabilities and how to use them,and choosing "
|
||||
"the right tool to solve the task",
|
||||
category="agent",
|
||||
key="dbgpt_agent_expand_plugin_assistant_agent_goal",
|
||||
),
|
||||
constraints=DynConfig(
|
||||
[
|
||||
"Achieve the goal step by step."
|
||||
"Each step, please read the parameter definition of the tool carefully "
|
||||
"and extract the specific parameters required to execute the tool "
|
||||
"from the user's goal.",
|
||||
"information in json format according to the following required format."
|
||||
"If there is an example, please refer to the sample format output.",
|
||||
"each step, you can only select one tool in action space.",
|
||||
],
|
||||
category="agent",
|
||||
key="dbgpt_agent_expand_plugin_assistant_agent_constraints",
|
||||
),
|
||||
system_prompt_template=_REACT_SYSTEM_TEMPLATE,
|
||||
write_memory_template=_REACT_WRITE_MEMORY_TEMPLATE,
|
||||
)
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
self._init_actions([ReActAction])
|
||||
|
||||
async def review(self, message: Optional[str], censored: Agent) -> Tuple[bool, Any]:
|
||||
"""Review the message based on the censored message."""
|
||||
try:
|
||||
json_obj = find_json_objects(message)
|
||||
if len(json_obj) == 0:
|
||||
raise ValueError(
|
||||
"No correct json object found in the message。"
|
||||
"Please strictly output JSON in the defined "
|
||||
"format, and only one action can be ouput each time. "
|
||||
)
|
||||
return True, json_obj[0]
|
||||
except Exception as e:
|
||||
logger.error(f"review error: {e}")
|
||||
raise e
|
||||
|
||||
def validate_action(self, action_name: str) -> bool:
|
||||
tools = self.resource.get_resource_by_type(ResourceType.Tool)
|
||||
for tool in tools:
|
||||
if tool.name == action_name:
|
||||
return True
|
||||
raise ValueError(f"{action_name} is not in the action space.")
|
||||
|
||||
async def generate_reply(
|
||||
self,
|
||||
received_message: AgentMessage,
|
||||
sender: Agent,
|
||||
reviewer: Optional[Agent] = None,
|
||||
rely_messages: Optional[List[AgentMessage]] = None,
|
||||
historical_dialogues: Optional[List[AgentMessage]] = None,
|
||||
is_retry_chat: bool = False,
|
||||
last_speaker_name: Optional[str] = None,
|
||||
**kwargs,
|
||||
) -> AgentMessage:
|
||||
"""Generate a reply based on the received messages."""
|
||||
try:
|
||||
logger.info(
|
||||
f"generate agent reply!sender={sender}, "
|
||||
f"rely_messages_len={rely_messages}"
|
||||
)
|
||||
self.validate_action(self.end_action_name)
|
||||
observation = AgentMessage(content="please start!")
|
||||
reply_message: AgentMessage = self._init_reply_message(
|
||||
received_message=received_message
|
||||
)
|
||||
thinking_messages, resource_info = await self._load_thinking_messages(
|
||||
received_message=observation,
|
||||
sender=sender,
|
||||
rely_messages=rely_messages,
|
||||
historical_dialogues=historical_dialogues,
|
||||
context=reply_message.get_dict_context(),
|
||||
is_retry_chat=is_retry_chat,
|
||||
)
|
||||
# attach current task to system prompt
|
||||
thinking_messages[0].content = (
|
||||
thinking_messages[0].content + "\n" + received_message.content
|
||||
)
|
||||
done = False
|
||||
max_steps = self.max_steps
|
||||
await self.write_memories(
|
||||
question=received_message.content,
|
||||
ai_message="",
|
||||
)
|
||||
while not done and max_steps > 0:
|
||||
ai_message = ""
|
||||
try:
|
||||
# 1. thinking
|
||||
llm_reply, model_name = await self.thinking(
|
||||
thinking_messages, sender
|
||||
)
|
||||
reply_message.model_name = model_name
|
||||
reply_message.resource_info = resource_info
|
||||
ai_message = llm_reply
|
||||
thinking_messages.append(
|
||||
AgentMessage(role=ModelMessageRoleType.AI, content=llm_reply)
|
||||
)
|
||||
approve, json_obj = await self.review(llm_reply, self)
|
||||
logger.info(f"jons_obj: {json_obj}")
|
||||
action = json_obj["Action"]
|
||||
thought = json_obj["Thought"]
|
||||
action.update({"thought": thought})
|
||||
reply_message.content = json.dumps(action, ensure_ascii=False)
|
||||
tool_name = action["tool_name"]
|
||||
self.validate_action(tool_name)
|
||||
# 2. act
|
||||
act_extent_param = self.prepare_act_param(
|
||||
received_message=received_message,
|
||||
sender=sender,
|
||||
rely_messages=rely_messages,
|
||||
historical_dialogues=historical_dialogues,
|
||||
)
|
||||
act_out: ActionOutput = await self.act(
|
||||
message=reply_message,
|
||||
sender=sender,
|
||||
reviewer=reviewer,
|
||||
is_retry_chat=is_retry_chat,
|
||||
last_speaker_name=last_speaker_name,
|
||||
**act_extent_param,
|
||||
)
|
||||
if act_out:
|
||||
reply_message.action_report = act_out
|
||||
|
||||
# 3. obs
|
||||
check_pass, reason = await self.verify(
|
||||
reply_message, sender, reviewer
|
||||
)
|
||||
done = tool_name == self.end_action_name and check_pass
|
||||
if check_pass:
|
||||
logger.info(f"Observation:{act_out.content}")
|
||||
thinking_messages.append(
|
||||
AgentMessage(
|
||||
role=ModelMessageRoleType.HUMAN,
|
||||
content=f"Observation: {tool_name} "
|
||||
f"output:{act_out.content}\n",
|
||||
)
|
||||
)
|
||||
await self.write_memories(
|
||||
question="",
|
||||
ai_message=ai_message,
|
||||
action_output=act_out,
|
||||
check_pass=check_pass,
|
||||
)
|
||||
else:
|
||||
observation = f"Observation: {reason}"
|
||||
logger.info(f"Observation:{observation}")
|
||||
thinking_messages.append(
|
||||
AgentMessage(
|
||||
role=ModelMessageRoleType.HUMAN, content=observation
|
||||
)
|
||||
)
|
||||
await self.write_memories(
|
||||
question="",
|
||||
ai_message=ai_message,
|
||||
check_pass=check_pass,
|
||||
check_fail_reason=reason,
|
||||
)
|
||||
max_steps -= 1
|
||||
except Exception as e:
|
||||
fail_reason = (
|
||||
f"Observation: Exception occurs:({type(e).__name__}){e}."
|
||||
)
|
||||
logger.error(fail_reason)
|
||||
thinking_messages.append(
|
||||
AgentMessage(
|
||||
role=ModelMessageRoleType.HUMAN, content=fail_reason
|
||||
)
|
||||
)
|
||||
await self.write_memories(
|
||||
question="",
|
||||
ai_message=ai_message,
|
||||
check_pass=False,
|
||||
check_fail_reason=fail_reason,
|
||||
)
|
||||
reply_message.success = done
|
||||
await self.adjust_final_message(True, reply_message)
|
||||
return reply_message
|
||||
except Exception as e:
|
||||
logger.exception("Generate reply exception!")
|
||||
err_message = AgentMessage(content=str(e))
|
||||
err_message.success = False
|
||||
return err_message
|
||||
|
||||
async def write_memories(
|
||||
self,
|
||||
question: str,
|
||||
ai_message: str,
|
||||
action_output: Optional[ActionOutput] = None,
|
||||
check_pass: bool = True,
|
||||
check_fail_reason: Optional[str] = None,
|
||||
) -> None:
|
||||
"""Write the memories to the memory.
|
||||
|
||||
We suggest you to override this method to save the conversation to memory
|
||||
according to your needs.
|
||||
|
||||
Args:
|
||||
question(str): The question received.
|
||||
ai_message(str): The AI message, LLM output.
|
||||
action_output(ActionOutput): The action output.
|
||||
check_pass(bool): Whether the check pass.
|
||||
check_fail_reason(str): The check fail reason.
|
||||
"""
|
||||
observation = ""
|
||||
if action_output and action_output.observations:
|
||||
observation = action_output.observations
|
||||
elif check_fail_reason:
|
||||
observation = check_fail_reason
|
||||
memory_map = {
|
||||
"question": question,
|
||||
"assistant": ai_message,
|
||||
"observation": observation,
|
||||
}
|
||||
write_memory_template = self.write_memory_template
|
||||
memory_content = self._render_template(write_memory_template, **memory_map)
|
||||
fragment = AgentMemoryFragment(memory_content)
|
||||
await self.memory.write(fragment)
|
Loading…
Reference in New Issue
Block a user