fix(agent): Fix agent message loss bug (#1283)

明天
2024-03-14 14:38:10 +08:00
committed by GitHub
parent adaa68eb00
commit a207640ff2
10 changed files with 276 additions and 89 deletions

View File

@@ -83,10 +83,10 @@ class PluginAction(Action[PluginInput]):
if not resource_plugin_client:
raise ValueError("No implementation of the use of plug-in resources")
response_success = True
status = Status.TODO.value
status = Status.RUNNING.value
tool_result = ""
err_msg = None
try:
status = Status.RUNNING.value
tool_result = await resource_plugin_client.a_execute_command(
param.tool_name, param.args, plugin_generator
)
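
The hunk above moves the status = Status.RUNNING.value assignment out of the try body, so an exception raised by a_execute_command can no longer leave the action report stuck at TODO. A minimal, self-contained sketch of that pattern; the client class below is a hypothetical stand-in for the real plugin resource client:

import asyncio
from enum import Enum

class Status(Enum):
    TODO = "todo"
    RUNNING = "running"

class DemoPluginClient:
    # Hypothetical stand-in for the real plugin resource client.
    async def a_execute_command(self, tool_name, args, generator=None):
        raise RuntimeError("plugin crashed")

async def run_tool(client, tool_name, args):
    status = Status.RUNNING.value  # set before the try block, as in the hunk above
    tool_result = ""
    err_msg = None
    response_success = True
    try:
        tool_result = await client.a_execute_command(tool_name, args)
    except Exception as e:
        response_success = False
        err_msg = str(e)
    # The report reflects that execution started, instead of a stale TODO value.
    return {"status": status, "success": response_success,
            "result": tool_result, "err_msg": err_msg}

print(asyncio.run(run_tool(DemoPluginClient(), "demo_tool", {})))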

View File

@@ -4,6 +4,12 @@ import dataclasses
from abc import ABC, abstractmethod
from typing import Any, Dict, List, Optional, Tuple, Union
from dbgpt.agent.resource.resource_loader import ResourceLoader
from dbgpt.core import LLMClient
from dbgpt.util.annotations import PublicAPI
from ..memory.gpts_memory import GptsMemory
class Agent(ABC):
async def a_send(
@@ -72,6 +78,8 @@ class Agent(ABC):
async def a_act(
self,
message: Optional[str],
sender: Optional[Agent] = None,
reviewer: Optional[Agent] = None,
**kwargs,
) -> Union[str, Dict, None]:
"""
@@ -101,3 +109,42 @@ class Agent(ABC):
Returns:
"""
@dataclasses.dataclass
class AgentContext:
conv_id: str
gpts_app_name: str = None
language: str = None
max_chat_round: Optional[int] = 100
max_retry_round: Optional[int] = 10
max_new_tokens: Optional[int] = 1024
temperature: Optional[float] = 0.5
allow_format_str_template: Optional[bool] = False
def to_dict(self) -> Dict[str, Any]:
return dataclasses.asdict(self)
@dataclasses.dataclass
@PublicAPI(stability="beta")
class AgentGenerateContext:
"""A class to represent the input of a Agent."""
message: Optional[Dict]
sender: Agent
reviewer: Agent
silent: Optional[bool] = False
rely_messages: List[Dict] = dataclasses.field(default_factory=list)
final: Optional[bool] = True
memory: Optional[GptsMemory] = None
agent_context: Optional[AgentContext] = None
resource_loader: Optional[ResourceLoader] = None
llm_client: Optional[LLMClient] = None
round_index: int = None
def to_dict(self) -> Dict:
return dataclasses.asdict(self)
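
A short usage sketch of the AgentContext dataclass defined above; the values are illustrative and only conv_id is required:

# Illustrative values only; conv_id is the single required field.
context = AgentContext(conv_id="conv-001", gpts_app_name="demo_app", language="en")
assert context.max_retry_round == 10   # default declared above
print(context.to_dict())               # plain dict via dataclasses.asdict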

View File

@@ -8,8 +8,7 @@ from typing import Any, Dict, List, Optional, Tuple, Union
from pydantic import BaseModel, Field
from dbgpt.agent.actions.action import Action, ActionOutput
from dbgpt.agent.agents.agent import AgentContext
from dbgpt.agent.agents.agent_new import Agent
from dbgpt.agent.agents.agent_new import Agent, AgentContext
from dbgpt.agent.agents.llm.llm import LLMConfig, LLMStrategyType
from dbgpt.agent.agents.llm.llm_client import AIWrapper
from dbgpt.agent.agents.role import Role
@@ -31,7 +30,7 @@ class ConversableAgent(Role, Agent):
llm_config: Optional[LLMConfig] = None
memory: GptsMemory = Field(default_factory=GptsMemory)
resource_loader: Optional[ResourceLoader] = None
max_retry_count: int = 10
max_retry_count: int = 3
consecutive_auto_reply_counter: int = 0
llm_client: Optional[AIWrapper] = None
oai_system_message: List[Dict] = Field(default_factory=list)
@@ -178,54 +177,75 @@ class ConversableAgent(Role, Agent):
logger.info(
f"generate agent reply!sender={sender}, rely_messages_len={rely_messages}"
)
reply_message = self._init_reply_message(recive_message=recive_message)
await self._system_message_assembly(
recive_message["content"], reply_message.get("context", None)
)
fail_reason = None
current_retry_counter = 0
is_sucess = True
while current_retry_counter < self.max_retry_count:
if current_retry_counter > 0:
retry_message = self._init_reply_message(recive_message=recive_message)
retry_message["content"] = fail_reason
# The retry message is a self-correcting message that needs to be recorded.
# It is temporarily attributed to the original sender so that the historical memory context stays organized.
await sender.a_send(retry_message, self, reviewer, request_reply=False)
# 1.Think about how to do things
llm_reply, model_name = await self.a_thinking(
self._load_thinking_messages(recive_message, sender, rely_messages)
try:
reply_message = self._init_reply_message(recive_message=recive_message)
await self._system_message_assembly(
recive_message["content"], reply_message.get("context", None)
)
reply_message["model_name"] = model_name
reply_message["content"] = llm_reply
# 2. Review whether the proposed action is valid
approve, comments = await self.a_review(llm_reply, self)
reply_message["review_info"] = {"approve": approve, "comments": comments}
fail_reason = None
current_retry_counter = 0
is_sucess = True
while current_retry_counter < self.max_retry_count:
if current_retry_counter > 0:
retry_message = self._init_reply_message(
recive_message=recive_message
)
retry_message["content"] = fail_reason
retry_message["current_goal"] = recive_message.get(
"current_goal", None
)
# The retry message is a self-correcting message that needs to be recorded.
# It is temporarily attributed to the original sender so that the historical memory context stays organized.
await sender.a_send(
retry_message, self, reviewer, request_reply=False
)
# 3.Act based on the results of your thinking
act_extent_param = self.prepare_act_param()
act_out: ActionOutput = await self.a_act(
message=llm_reply,
**act_extent_param,
)
reply_message["action_report"] = act_out.dict()
# 1.Think about how to do things
llm_reply, model_name = await self.a_thinking(
self._load_thinking_messages(recive_message, sender, rely_messages)
)
reply_message["model_name"] = model_name
reply_message["content"] = llm_reply
# 4.Reply information verification
check_paas, reason = await self.a_verify(reply_message, sender, reviewer)
is_sucess = check_paas
# 5.Optimize wrong answers myself
if not check_paas:
current_retry_counter += 1
# Send error messages and issue new problem-solving instructions
await self.a_send(reply_message, sender, reviewer, request_reply=False)
fail_reason = reason
else:
break
return is_sucess, reply_message
# 2. Review whether the proposed action is valid
approve, comments = await self.a_review(llm_reply, self)
reply_message["review_info"] = {
"approve": approve,
"comments": comments,
}
# 3.Act based on the results of your thinking
act_extent_param = self.prepare_act_param()
act_out: ActionOutput = await self.a_act(
message=llm_reply,
sender=sender,
reviewer=reviewer,
**act_extent_param,
)
reply_message["action_report"] = act_out.dict()
# 4.Reply information verification
check_paas, reason = await self.a_verify(
reply_message, sender, reviewer
)
is_sucess = check_paas
# 5.Optimize wrong answers myself
if not check_paas:
current_retry_counter += 1
# Send error messages and issue new problem-solving instructions
if current_retry_counter < self.max_retry_count:
await self.a_send(
reply_message, sender, reviewer, request_reply=False
)
fail_reason = reason
else:
break
return is_sucess, reply_message
except Exception as e:
logger.exception("Generate reply exception!")
return False, {"content": str(e)}
async def a_thinking(
self, messages: Optional[List[Dict]], prompt: Optional[str] = None
@@ -265,7 +285,13 @@ class ConversableAgent(Role, Agent):
) -> Tuple[bool, Any]:
return True, None
async def a_act(self, message: Optional[str], **kwargs) -> Optional[ActionOutput]:
async def a_act(
self,
message: Optional[str],
sender: Optional[ConversableAgent] = None,
reviewer: Optional[ConversableAgent] = None,
**kwargs,
) -> Optional[ActionOutput]:
last_out = None
for action in self.actions:
# Select the resources required by the action
@@ -335,6 +361,7 @@ class ConversableAgent(Role, Agent):
#######################################################################
def _init_actions(self, actions: List[Action] = None):
self.actions = []
for idx, action in enumerate(actions):
if not isinstance(action, Action):
self.actions.append(action())
@@ -426,7 +453,9 @@ class ConversableAgent(Role, Agent):
for item in self.resources:
resource_client = self.resource_loader.get_resesource_api(item.type)
resource_prompt_list.append(
await resource_client.get_resource_prompt(item, qustion)
await resource_client.get_resource_prompt(
self.agent_context.conv_id, item, qustion
)
)
if context is None:
context = {}
@@ -525,7 +554,11 @@ class ConversableAgent(Role, Agent):
content = item.content
if item.action_report:
action_out = ActionOutput.from_dict(json.loads(item.action_report))
if action_out is not None and action_out.content is not None:
if (
action_out is not None
and action_out.is_exe_success
and action_out.content is not None
):
content = action_out.content
oai_messages.append(
{
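
Taken together, the rewritten generate-reply flow above guards the whole think / review / act / verify / retry loop with a try/except, so an unexpected exception comes back as a failed reply instead of silently dropping the message, and the failure message is only sent back to the sender while another retry is still possible. A reduced, self-contained sketch of that control flow, with stand-in think and verify steps in place of the real LLM and verification calls:

import asyncio

MAX_RETRY_COUNT = 3

async def think(task, fail_reason=None):
    # Stand-in for a_thinking(): pretend the retry prompt produces a better answer.
    return f"answer for {task!r}" if fail_reason else "draft answer"

async def verify(reply):
    # Stand-in for a_verify(): accept only replies that start with "answer".
    ok = reply.startswith("answer")
    return ok, None if ok else "draft was rejected, please refine it"

async def generate_reply(task):
    try:
        fail_reason = None
        current_retry_counter = 0
        is_success = True
        reply = ""
        while current_retry_counter < MAX_RETRY_COUNT:
            reply = await think(task, fail_reason)           # 1. think
            ok, reason = await verify(reply)                  # 4. verify
            is_success = ok
            if not ok:
                current_retry_counter += 1
                if current_retry_counter < MAX_RETRY_COUNT:   # only resend while a retry remains
                    fail_reason = reason
            else:
                break
        return is_success, reply
    except Exception as e:
        # Surface the error as a failed reply instead of losing the message.
        return False, str(e)

print(asyncio.run(generate_reply("summarize the report")))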

View File

@@ -1,15 +1,18 @@
from __future__ import annotations
import json
from collections import defaultdict
from collections import OrderedDict, defaultdict
from typing import Dict, List, Optional
from dbgpt.agent.actions.action import ActionOutput
from dbgpt.util.json_utils import EnhancedJSONEncoder
from dbgpt.vis.client import VisAgentMessages, VisAgentPlans, vis_client
from .base import GptsMessage, GptsMessageMemory, GptsPlansMemory
from .default_gpts_memory import DefaultGptsMessageMemory, DefaultGptsPlansMemory
NONE_GOAL_PREFIX: str = "none_goal_count_"
class GptsMemory:
def __init__(
@@ -32,6 +35,41 @@ class GptsMemory:
def message_memory(self):
return self._message_memory
async def _message_group_vis_build(self, message_group):
if not message_group:
return ""
num: int = 0
last_goal = next(reversed(message_group))
last_goal_messages = message_group[last_goal]
last_goal_message = last_goal_messages[-1]
vis_items = []
plan_temps = []
for key, value in message_group.items():
num = num + 1
if key.startswith(NONE_GOAL_PREFIX):
vis_items.append(await self._messages_to_plan_vis(plan_temps))
plan_temps = []
num = 0
vis_items.append(await self._messages_to_agents_vis(value))
else:
num += 1
plan_temps.append(
{
"name": key,
"num": num,
"status": "complete",
"agent": value[0].receiver if value else "",
"markdown": await self._messages_to_agents_vis(value),
}
)
if len(plan_temps) > 0:
vis_items.append(await self._messages_to_plan_vis(plan_temps))
vis_items.append(await self._messages_to_agents_vis([last_goal_message]))
return "\n".join(vis_items)
async def _plan_vis_build(self, plan_group: dict[str, list]):
num: int = 0
plan_items = []
@@ -48,6 +86,37 @@ class GptsMemory:
)
return await self._messages_to_plan_vis(plan_items)
async def one_chat_competions_v2(self, conv_id: str):
messages = self.message_memory.get_by_conv_id(conv_id=conv_id)
temp_group = OrderedDict()
none_goal_count = 1
count: int = 0
for message in messages:
count = count + 1
if count == 1:
continue
current_gogal = message.current_goal
last_goal = next(reversed(temp_group)) if temp_group else None
if last_goal:
last_goal_messages = temp_group[last_goal]
if current_gogal:
if current_gogal == last_goal:
last_goal_messages.append(message)
else:
temp_group[current_gogal] = [message]
else:
temp_group[f"{NONE_GOAL_PREFIX}{none_goal_count}"] = [message]
none_goal_count += 1
else:
if current_gogal:
temp_group[current_gogal] = [message]
else:
temp_group[f"{NONE_GOAL_PREFIX}{none_goal_count}"] = [message]
none_goal_count += 1
return await self._message_group_vis_build(temp_group)
async def one_chat_competions(self, conv_id: str):
messages = self.message_memory.get_by_conv_id(conv_id=conv_id)
temp_group = defaultdict(list)
@@ -76,12 +145,14 @@ class GptsMemory:
vis_items.append(await self._plan_vis_build(temp_group))
temp_group.clear()
if len(temp_messages) > 0:
vis_items.append(await self._messages_to_agents_vis(temp_messages))
vis_items.append(await self._messages_to_agents_vis(temp_messages, True))
temp_messages.clear()
return "\n".join(vis_items)
async def _messages_to_agents_vis(self, messages: List[GptsMessage]):
async def _messages_to_agents_vis(
self, messages: List[GptsMessage], is_last_message: bool = False
):
if messages is None or len(messages) <= 0:
return ""
messages_view = []
@@ -89,10 +160,11 @@ class GptsMemory:
action_report_str = message.action_report
view_info = message.content
if action_report_str and len(action_report_str) > 0:
action_report = json.loads(action_report_str)
if action_report:
view = action_report.get("view", None)
view_info = view if view else action_report.get("content", "")
action_out = ActionOutput.from_dict(json.loads(action_report_str))
if action_out is not None:
if action_out.is_exe_success or is_last_message:
view = action_out.view
view_info = view if view else action_out.content
messages_view.append(
{
@@ -102,9 +174,8 @@ class GptsMemory:
"markdown": view_info,
}
)
return await vis_client.get(VisAgentMessages.vis_tag()).display(
content=messages_view
)
vis_compent = vis_client.get(VisAgentMessages.vis_tag())
return await vis_compent.display(content=messages_view)
async def _messages_to_plan_vis(self, messages: List[Dict]):
if messages is None or len(messages) <= 0:
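
In the memory changes above, one_chat_competions_v2 groups a conversation's messages by their current goal: consecutive messages that share a goal are appended to the same bucket of an OrderedDict, goal-less messages each get their own NONE_GOAL_PREFIX bucket, and _message_group_vis_build then renders one agent or plan block per bucket. A simplified, standalone sketch of just the grouping step, with each message reduced to a (current_goal, content) pair and the skip-first-message detail omitted:

from collections import OrderedDict

NONE_GOAL_PREFIX = "none_goal_count_"

def group_by_goal(messages):
    temp_group = OrderedDict()
    none_goal_count = 1
    for current_goal, content in messages:
        last_goal = next(reversed(temp_group)) if temp_group else None
        if current_goal and current_goal == last_goal:
            temp_group[last_goal].append(content)        # same goal: extend the bucket
        elif current_goal:
            temp_group[current_goal] = [content]         # new goal: start a new bucket
        else:
            # Goal-less messages get their own numbered bucket.
            temp_group[f"{NONE_GOAL_PREFIX}{none_goal_count}"] = [content]
            none_goal_count += 1
    return temp_group

demo = [("goal A", "msg 1"), ("goal A", "msg 2"), (None, "status update"), ("goal B", "msg 3")]
print(group_by_goal(demo))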

View File

@@ -14,7 +14,10 @@ class ResourceType(Enum):
Knowledge = "knowledge"
Internet = "internet"
Plugin = "plugin"
File = "file"
TextFile = "text_file"
ExcelFile = "excel_file"
ImageFile = "image_file"
AwelFlow = "awel_flow"
class AgentResource(BaseModel):
@@ -80,7 +83,7 @@ class ResourceClient(ABC):
return ""
async def get_resource_prompt(
self, resource: AgentResource, question: Optional[str] = None
self, conv_uid, resource: AgentResource, question: Optional[str] = None
) -> str:
return resource.resource_prompt_template().format(
data_type=self.get_data_type(resource),
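
Finally, ResourceClient.get_resource_prompt now receives the conversation id (conv_uid) as its first argument, so a resource client can scope the generated prompt to the current conversation. A hedged sketch of a client overriding the new signature; the class name, resource value, and prompt text are illustrative, not the repository implementation:

import asyncio
from typing import Optional

class DemoFileResourceClient:
    # Illustrative override of the new conv_uid-aware signature.
    async def get_resource_prompt(self, conv_uid, resource, question: Optional[str] = None) -> str:
        # A real client would look up the resource attached to this conversation;
        # here we only show how conv_uid flows into the prompt.
        return f"[conv {conv_uid}] resource {resource!r} for question: {question}"

print(asyncio.run(
    DemoFileResourceClient().get_resource_prompt("conv-001", "report.xlsx", "total sales?")
))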