fix(agent): Fix agent loss message bug (#1283)

This commit is contained in:
明天
2024-03-14 14:38:10 +08:00
committed by GitHub
parent adaa68eb00
commit a207640ff2
10 changed files with 276 additions and 89 deletions

View File

@@ -10,7 +10,7 @@ from fastapi import APIRouter, Body
from fastapi.responses import StreamingResponse
from dbgpt._private.config import Config
from dbgpt.agent.agents.agent import Agent, AgentContext
from dbgpt.agent.agents.agent_new import Agent, AgentContext
from dbgpt.agent.agents.agents_manage import agent_manage
from dbgpt.agent.agents.base_agent_new import ConversableAgent
from dbgpt.agent.agents.llm.llm import LLMConfig, LLMStrategyType
@@ -121,7 +121,7 @@ class MultiAgents(BaseComponent, ABC):
)
)
asyncio.create_task(
task = asyncio.create_task(
multi_agents.agent_team_chat_new(
user_query, agent_conv_id, gpt_app, is_retry_chat
)
@@ -129,17 +129,19 @@ class MultiAgents(BaseComponent, ABC):
async for chunk in multi_agents.chat_messages(agent_conv_id):
if chunk:
logger.info(chunk)
try:
chunk = json.dumps(
{"vis": chunk}, default=serialize, ensure_ascii=False
)
yield f"data: {chunk}\n\n"
if chunk is None or len(chunk) <= 0:
continue
resp = f"data:{chunk}\n\n"
yield task, resp
except Exception as e:
logger.exception(f"get messages {gpts_name} Exception!" + str(e))
yield f"data: {str(e)}\n\n"
yield f'data:{json.dumps({"vis": "[DONE]"}, default=serialize, ensure_ascii=False)} \n\n'
yield task, f'data:{json.dumps({"vis": "[DONE]"}, default=serialize, ensure_ascii=False)} \n\n'
async def app_agent_chat(
self,
@@ -164,19 +166,30 @@ class MultiAgents(BaseComponent, ABC):
current_message.save_to_storage()
current_message.start_new_round()
current_message.add_user_message(user_query)
agent_conv_id = conv_uid + "_" + str(current_message.chat_order)
agent_task = None
try:
agent_conv_id = conv_uid + "_" + str(current_message.chat_order)
async for chunk in multi_agents.agent_chat(
async for task, chunk in multi_agents.agent_chat(
agent_conv_id, gpts_name, user_query, user_code, sys_code
):
agent_task = task
yield chunk
final_message = await self.stable_message(agent_conv_id)
except asyncio.CancelledError:
# Client disconnects
print("Client disconnected")
if agent_task:
logger.info(f"Chat to App {gpts_name}:{agent_conv_id} Cancel!")
agent_task.cancel()
except Exception as e:
logger.exception(f"Chat to App {gpts_name} Failed!" + str(e))
raise
finally:
logger.info(f"save agent chat info{conv_uid},{agent_conv_id}")
current_message.add_view_message(final_message)
logger.info(f"save agent chat info{conv_uid}")
final_message = await self.stable_message(agent_conv_id)
if final_message:
current_message.add_view_message(final_message)
current_message.end_current_round()
current_message.save_to_storage()
@@ -288,7 +301,7 @@ class MultiAgents(BaseComponent, ABC):
]
else False
)
message = await self.memory.one_chat_competions(conv_id)
message = await self.memory.one_chat_competions_v2(conv_id)
yield message
if is_complete:
@@ -308,11 +321,12 @@ class MultiAgents(BaseComponent, ABC):
else False
)
if is_complete:
return await self.memory.one_chat_competions(conv_id)
return await self.memory.one_chat_competions_v2(conv_id)
else:
raise ValueError(
"The conversation has not been completed yet, so we cannot directly obtain information."
)
pass
# raise ValueError(
# "The conversation has not been completed yet, so we cannot directly obtain information."
# )
else:
raise ValueError("No conversation record found!")