fix(agent): Fix agent loss message bug (#1283)

This commit is contained in:
明天
2024-03-14 14:38:10 +08:00
committed by GitHub
parent adaa68eb00
commit a207640ff2
10 changed files with 276 additions and 89 deletions

View File

@@ -10,7 +10,7 @@ from fastapi import APIRouter, Body
from fastapi.responses import StreamingResponse
from dbgpt._private.config import Config
from dbgpt.agent.agents.agent import Agent, AgentContext
from dbgpt.agent.agents.agent_new import Agent, AgentContext
from dbgpt.agent.agents.agents_manage import agent_manage
from dbgpt.agent.agents.base_agent_new import ConversableAgent
from dbgpt.agent.agents.llm.llm import LLMConfig, LLMStrategyType
@@ -121,7 +121,7 @@ class MultiAgents(BaseComponent, ABC):
)
)
asyncio.create_task(
task = asyncio.create_task(
multi_agents.agent_team_chat_new(
user_query, agent_conv_id, gpt_app, is_retry_chat
)
@@ -129,17 +129,19 @@ class MultiAgents(BaseComponent, ABC):
async for chunk in multi_agents.chat_messages(agent_conv_id):
if chunk:
logger.info(chunk)
try:
chunk = json.dumps(
{"vis": chunk}, default=serialize, ensure_ascii=False
)
yield f"data: {chunk}\n\n"
if chunk is None or len(chunk) <= 0:
continue
resp = f"data:{chunk}\n\n"
yield task, resp
except Exception as e:
logger.exception(f"get messages {gpts_name} Exception!" + str(e))
yield f"data: {str(e)}\n\n"
yield f'data:{json.dumps({"vis": "[DONE]"}, default=serialize, ensure_ascii=False)} \n\n'
yield task, f'data:{json.dumps({"vis": "[DONE]"}, default=serialize, ensure_ascii=False)} \n\n'
async def app_agent_chat(
self,
@@ -164,19 +166,30 @@ class MultiAgents(BaseComponent, ABC):
current_message.save_to_storage()
current_message.start_new_round()
current_message.add_user_message(user_query)
agent_conv_id = conv_uid + "_" + str(current_message.chat_order)
agent_task = None
try:
agent_conv_id = conv_uid + "_" + str(current_message.chat_order)
async for chunk in multi_agents.agent_chat(
async for task, chunk in multi_agents.agent_chat(
agent_conv_id, gpts_name, user_query, user_code, sys_code
):
agent_task = task
yield chunk
final_message = await self.stable_message(agent_conv_id)
except asyncio.CancelledError:
# Client disconnects
print("Client disconnected")
if agent_task:
logger.info(f"Chat to App {gpts_name}:{agent_conv_id} Cancel!")
agent_task.cancel()
except Exception as e:
logger.exception(f"Chat to App {gpts_name} Failed!" + str(e))
raise
finally:
logger.info(f"save agent chat info{conv_uid},{agent_conv_id}")
current_message.add_view_message(final_message)
logger.info(f"save agent chat info{conv_uid}")
final_message = await self.stable_message(agent_conv_id)
if final_message:
current_message.add_view_message(final_message)
current_message.end_current_round()
current_message.save_to_storage()
@@ -288,7 +301,7 @@ class MultiAgents(BaseComponent, ABC):
]
else False
)
message = await self.memory.one_chat_competions(conv_id)
message = await self.memory.one_chat_competions_v2(conv_id)
yield message
if is_complete:
@@ -308,11 +321,12 @@ class MultiAgents(BaseComponent, ABC):
else False
)
if is_complete:
return await self.memory.one_chat_competions(conv_id)
return await self.memory.one_chat_competions_v2(conv_id)
else:
raise ValueError(
"The conversation has not been completed yet, so we cannot directly obtain information."
)
pass
# raise ValueError(
# "The conversation has not been completed yet, so we cannot directly obtain information."
# )
else:
raise ValueError("No conversation record found!")

View File

@@ -27,14 +27,15 @@ class PluginHubLoadClient(ResourcePluginClient):
self, value: str, plugin_generator: Optional[PluginPromptGenerator] = None
) -> PluginPromptGenerator:
logger.info(f"PluginHubLoadClient load plugin:{value}")
plugins_prompt_generator = PluginPromptGenerator()
plugins_prompt_generator.command_registry = CFG.command_registry
if plugin_generator is None:
plugin_generator = PluginPromptGenerator()
plugin_generator.command_registry = CFG.command_registry
agent_module = CFG.SYSTEM_APP.get_component(
ComponentType.PLUGIN_HUB, ModulePlugin
)
plugins_prompt_generator = agent_module.load_select_plugin(
plugins_prompt_generator, json.dumps(value)
plugin_generator = agent_module.load_select_plugin(
plugin_generator, json.dumps(value)
)
return plugins_prompt_generator
return plugin_generator

View File

@@ -1,7 +1,7 @@
from abc import ABC
from typing import Dict, List, Optional
from dbgpt.agent.agents.agent import Agent, AgentGenerateContext
from dbgpt.agent.agents.agent_new import Agent, AgentGenerateContext
from dbgpt.agent.agents.agents_manage import agent_manage
from dbgpt.agent.agents.base_agent_new import ConversableAgent
from dbgpt.agent.agents.llm.llm import LLMConfig
@@ -191,6 +191,7 @@ class AwelAgentOperator(
agent_context=input_value.agent_context,
resource_loader=input_value.resource_loader,
llm_client=input_value.llm_client,
round_index=agent.consecutive_auto_reply_counter,
)
async def get_agent(
@@ -208,11 +209,19 @@ class AwelAgentOperator(
llm_config = LLMConfig(llm_client=input_value.llm_client)
else:
llm_config = LLMConfig(llm_client=self.llm_client)
else:
if not llm_config.llm_client:
if input_value.llm_client:
llm_config.llm_client = input_value.llm_client
else:
llm_config.llm_client = self.llm_client
kwargs = {}
if self.awel_agent.role_name:
kwargs["name"] = self.awel_agent.role_name
if self.awel_agent.fixed_subgoal:
kwargs["fixed_subgoal"] = self.awel_agent.fixed_subgoal
agent = (
await agent_cls(**kwargs)
.bind(input_value.memory)
@@ -222,6 +231,7 @@ class AwelAgentOperator(
.bind(input_value.resource_loader)
.build()
)
return agent

View File

@@ -1,12 +1,12 @@
import json
import logging
from typing import Any, List, Optional
from typing import Any, Dict, List, Optional
from pydantic import BaseModel, Field, validator
from dbgpt._private.config import Config
from dbgpt.agent.actions.action import ActionOutput, T
from dbgpt.agent.agents.agent import Agent, AgentContext, AgentGenerateContext
from dbgpt.agent.agents.agent_new import Agent, AgentContext, AgentGenerateContext
from dbgpt.agent.agents.base_agent_new import ConversableAgent
from dbgpt.agent.agents.base_team import ManagerAgent
from dbgpt.core.awel import DAG
@@ -35,6 +35,9 @@ class AwelLayoutChatNewManager(ManagerAgent):
assert value is not None and value != "", "dag must not be empty"
return value
async def _a_process_received_message(self, message: Optional[Dict], sender: Agent):
pass
async def a_act(
self,
message: Optional[str],
@@ -60,7 +63,7 @@ class AwelLayoutChatNewManager(ManagerAgent):
"content": message,
"current_goal": message,
},
sender=self,
sender=sender,
reviewer=reviewer,
memory=self.memory,
agent_context=self.agent_context,
@@ -73,8 +76,11 @@ class AwelLayoutChatNewManager(ManagerAgent):
last_message = final_generate_context.rely_messages[-1]
last_agent = await last_node.get_agent(final_generate_context)
last_agent.consecutive_auto_reply_counter = (
final_generate_context.round_index
)
await last_agent.a_send(
last_message, self, start_message_context.reviewer, False
last_message, sender, start_message_context.reviewer, False
)
return ActionOutput(

View File

@@ -35,7 +35,9 @@ class AutoPlanChatManager(ManagerAgent):
if now_plan.rely and len(now_plan.rely) > 0:
rely_tasks_list = now_plan.rely.split(",")
rely_tasks = self.memory.plans_memory.get_by_conv_id_and_num(conv_id, [])
rely_tasks = self.memory.plans_memory.get_by_conv_id_and_num(
conv_id, rely_tasks_list
)
if rely_tasks:
rely_prompt = "Read the result data of the dependent steps in the above historical message to complete the current goal:"
for rely_task in rely_tasks: