feat(agent): Optimize agent memory (#2665)

This commit is contained in:
Fangyin Cheng 2025-04-30 09:49:17 +08:00 committed by GitHub
parent b901cbc9a6
commit 3a00aca113
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
8 changed files with 10637 additions and 6023 deletions

View File

@ -5,6 +5,7 @@ from __future__ import annotations
import asyncio import asyncio
import json import json
import logging import logging
import time
from concurrent.futures import Executor, ThreadPoolExecutor from concurrent.futures import Executor, ThreadPoolExecutor
from datetime import datetime from datetime import datetime
from typing import Any, Callable, Dict, List, Optional, Tuple, Type, final from typing import Any, Callable, Dict, List, Optional, Tuple, Type, final
@ -45,6 +46,7 @@ class ConversableAgent(Role, Agent):
bind_prompt: Optional[PromptTemplate] = None bind_prompt: Optional[PromptTemplate] = None
run_mode: Optional[AgentRunMode] = Field(default=None, description="Run mode") run_mode: Optional[AgentRunMode] = Field(default=None, description="Run mode")
max_retry_count: int = 3 max_retry_count: int = 3
max_timeout: int = 600
llm_client: Optional[AIWrapper] = None llm_client: Optional[AIWrapper] = None
# 确认当前Agent是否需要进行流式输出 # 确认当前Agent是否需要进行流式输出
stream_out: bool = True stream_out: bool = True
@ -363,6 +365,7 @@ class ConversableAgent(Role, Agent):
fail_reason = None fail_reason = None
current_retry_counter = 0 current_retry_counter = 0
start_time = time.time()
is_success = True is_success = True
observation = received_message.content or "" observation = received_message.content or ""
while current_retry_counter < self.max_retry_count: while current_retry_counter < self.max_retry_count:
@ -402,10 +405,12 @@ class ConversableAgent(Role, Agent):
thinking_messages, resource_info = await self._load_thinking_messages( thinking_messages, resource_info = await self._load_thinking_messages(
received_message=received_message, received_message=received_message,
sender=sender, sender=sender,
observation=observation,
rely_messages=rely_messages, rely_messages=rely_messages,
historical_dialogues=historical_dialogues, historical_dialogues=historical_dialogues,
context=reply_message.get_dict_context(), context=reply_message.get_dict_context(),
is_retry_chat=is_retry_chat, is_retry_chat=is_retry_chat,
current_retry_counter=current_retry_counter,
) )
with root_tracer.start_span( with root_tracer.start_span(
"agent.generate_reply.thinking", "agent.generate_reply.thinking",
@ -493,6 +498,7 @@ class ConversableAgent(Role, Agent):
logger.warning("No retry available!") logger.warning("No retry available!")
break break
fail_reason = reason fail_reason = reason
observation = fail_reason
await self.write_memories( await self.write_memories(
question=question, question=question,
ai_message=ai_message, ai_message=ai_message,
@ -514,6 +520,13 @@ class ConversableAgent(Role, Agent):
if self.run_mode != AgentRunMode.LOOP or act_out.terminate: if self.run_mode != AgentRunMode.LOOP or act_out.terminate:
logger.debug(f"Agent {self.name} reply success!{reply_message}") logger.debug(f"Agent {self.name} reply success!{reply_message}")
break break
time_cost = time.time() - start_time
if time_cost > self.max_timeout:
logger.warning(
f"Agent {self.name} run time out!{time_cost} > "
f"{self.max_timeout}"
)
break
# Continue to run the next round # Continue to run the next round
current_retry_counter += 1 current_retry_counter += 1
@ -1072,15 +1085,25 @@ class ConversableAgent(Role, Agent):
self, self,
received_message: AgentMessage, received_message: AgentMessage,
sender: Agent, sender: Agent,
observation: Optional[str] = None,
rely_messages: Optional[List[AgentMessage]] = None, rely_messages: Optional[List[AgentMessage]] = None,
historical_dialogues: Optional[List[AgentMessage]] = None, historical_dialogues: Optional[List[AgentMessage]] = None,
context: Optional[Dict[str, Any]] = None, context: Optional[Dict[str, Any]] = None,
is_retry_chat: bool = False, is_retry_chat: bool = False,
current_retry_counter: Optional[int] = None,
) -> Tuple[List[AgentMessage], Optional[Dict]]: ) -> Tuple[List[AgentMessage], Optional[Dict]]:
observation = received_message.content question = received_message.content
if not observation: observation = observation or question
if not question:
raise ValueError("The received message content is empty!") raise ValueError("The received message content is empty!")
most_recent_memories = ""
memory_list = []
# Read the memories according to the current observation
memories = await self.read_memories(observation) memories = await self.read_memories(observation)
if isinstance(memories, list):
memory_list = memories
else:
most_recent_memories = memories
has_memories = True if memories else False has_memories = True if memories else False
reply_message_str = "" reply_message_str = ""
if context is None: if context is None:
@ -1102,8 +1125,9 @@ class ConversableAgent(Role, Agent):
elif message.role == ModelMessageRoleType.AI: elif message.role == ModelMessageRoleType.AI:
reply_message_str += f"Observation: {message.content}\n" reply_message_str += f"Observation: {message.content}\n"
if reply_message_str: if reply_message_str:
memories += "\n" + reply_message_str most_recent_memories += "\n" + reply_message_str
try: try:
# Load the resource prompt according to the current observation
resource_prompt_str, resource_references = await self.load_resource( resource_prompt_str, resource_references = await self.load_resource(
observation, is_retry_chat=is_retry_chat observation, is_retry_chat=is_retry_chat
) )
@ -1114,21 +1138,19 @@ class ConversableAgent(Role, Agent):
resource_vars = await self.generate_resource_variables(resource_prompt_str) resource_vars = await self.generate_resource_variables(resource_prompt_str)
system_prompt = await self.build_system_prompt( system_prompt = await self.build_system_prompt(
question=observation, question=question,
most_recent_memories=memories, most_recent_memories=most_recent_memories,
resource_vars=resource_vars, resource_vars=resource_vars,
context=context, context=context,
is_retry_chat=is_retry_chat, is_retry_chat=is_retry_chat,
) )
user_prompt = await self.build_prompt( user_prompt = await self.build_prompt(
question=observation, question=question,
is_system=False, is_system=False,
most_recent_memories=memories, most_recent_memories=most_recent_memories,
resource_vars=resource_vars, resource_vars=resource_vars,
**context, **context,
) )
if not user_prompt:
user_prompt = f"Observation: {observation}"
agent_messages = [] agent_messages = []
if system_prompt: if system_prompt:
@ -1153,14 +1175,21 @@ class ConversableAgent(Role, Agent):
message.role = ModelMessageRoleType.AI message.role = ModelMessageRoleType.AI
agent_messages.append(message) agent_messages.append(message)
if memory_list:
agent_messages.extend(memory_list)
# Current user input information # Current user input information
if not user_prompt and (not memory_list or not current_retry_counter):
# The user prompt is empty, and the current retry count is 0 or the memory
# is empty
user_prompt = f"Observation: {observation}"
if user_prompt:
agent_messages.append( agent_messages.append(
AgentMessage( AgentMessage(
content=user_prompt, content=user_prompt,
role=ModelMessageRoleType.HUMAN, role=ModelMessageRoleType.HUMAN,
) )
) )
return agent_messages, resource_references return agent_messages, resource_references

View File

@ -160,10 +160,12 @@ class ManagerAgent(ConversableAgent, Team):
self, self,
received_message: AgentMessage, received_message: AgentMessage,
sender: Agent, sender: Agent,
observation: Optional[str] = None,
rely_messages: Optional[List[AgentMessage]] = None, rely_messages: Optional[List[AgentMessage]] = None,
historical_dialogues: Optional[List[AgentMessage]] = None, historical_dialogues: Optional[List[AgentMessage]] = None,
context: Optional[Dict[str, Any]] = None, context: Optional[Dict[str, Any]] = None,
is_retry_chat: bool = False, is_retry_chat: bool = False,
current_retry_counter: Optional[int] = None,
) -> Tuple[List[AgentMessage], Optional[Dict]]: ) -> Tuple[List[AgentMessage], Optional[Dict]]:
"""Load messages for thinking.""" """Load messages for thinking."""
return [AgentMessage(content=received_message.content)], None return [AgentMessage(content=received_message.content)], None

View File

@ -1,6 +1,10 @@
"""Memory module for the agent.""" """Memory module for the agent."""
from .agent_memory import AgentMemory, AgentMemoryFragment # noqa: F401 from .agent_memory import ( # noqa: F401
AgentMemory,
AgentMemoryFragment,
StructuredAgentMemoryFragment,
)
from .base import ( # noqa: F401 from .base import ( # noqa: F401
ImportanceScorer, ImportanceScorer,
InsightExtractor, InsightExtractor,

View File

@ -1,7 +1,11 @@
"""Agent memory module.""" """Agent memory module."""
import json
import logging
from datetime import datetime from datetime import datetime
from typing import Callable, List, Optional, Type, cast from typing import Callable, List, Optional, Type, Union, cast
from typing_extensions import TypedDict
from dbgpt.core import LLMClient from dbgpt.core import LLMClient
from dbgpt.util.annotations import immutable, mutable from dbgpt.util.annotations import immutable, mutable
@ -18,6 +22,18 @@ from .base import (
) )
from .gpts import GptsMemory, GptsMessageMemory, GptsPlansMemory from .gpts import GptsMemory, GptsMessageMemory, GptsPlansMemory
logger = logging.getLogger(__name__)
class StructuredObservation(TypedDict, total=False):
    """Structured observation for agent memory.

    One reasoning step stored as a JSON-serialisable mapping.

    The mapping is declared with ``total=False`` because producers do not
    always supply every key: ``question`` is only stored for the first round
    and ``action_input`` only when the action actually had an input. Consumers
    should therefore read the fields with ``.get()``.
    """

    # The question that triggered the step; may be absent on later rounds.
    question: Optional[str]
    # The reasoning text produced before acting.
    thought: Optional[str]
    # The name of the action that was chosen.
    action: Optional[str]
    # The input handed to the action; absent when the action took no input.
    action_input: Optional[str]
    # The result observed after the action (or the failure reason).
    observation: Optional[str]
class AgentMemoryFragment(MemoryFragment): class AgentMemoryFragment(MemoryFragment):
"""Default memory fragment for agent memory.""" """Default memory fragment for agent memory."""
@ -168,6 +184,94 @@ class AgentMemoryFragment(MemoryFragment):
) )
class StructuredAgentMemoryFragment(AgentMemoryFragment):
    """Structured memory fragment for agent memory.

    The fragment keeps the structured observation as given and hands its JSON
    text to the parent class as the raw observation.
    """

    def __init__(
        self,
        observation: Union[StructuredObservation, List[StructuredObservation]],
        embeddings: Optional[List[float]] = None,
        memory_id: Optional[int] = None,
        importance: Optional[float] = None,
        last_accessed_time: Optional[datetime] = None,
        is_insight: bool = False,
    ):
        """Create a structured memory fragment.

        Args:
            observation: A single structured observation, or a list of them.
            embeddings: Optional embedding vector, forwarded to the parent.
            memory_id: Optional memory id, forwarded to the parent.
            importance: Optional importance score, forwarded to the parent.
            last_accessed_time: Optional access time, forwarded to the parent.
            is_insight: Whether the fragment is an insight, forwarded to the
                parent.
        """
        super().__init__(
            observation=self.to_raw_observation(observation),
            embeddings=embeddings,
            memory_id=memory_id,
            importance=importance,
            last_accessed_time=last_accessed_time,
            is_insight=is_insight,
        )
        # Keep the structured form next to the JSON text held by the parent.
        self._structured_observation = observation

    def to_raw_observation(
        self, observation: Union[StructuredObservation, List[StructuredObservation]]
    ) -> str:
        """Convert the structured observation to a raw observation.

        Args:
            observation: A structured observation or a list of them.

        Returns:
            str: The observation serialised as JSON; non-ASCII characters are
                kept as-is rather than escaped.
        """
        return json.dumps(observation, ensure_ascii=False)

    @classmethod
    def build_from(
        cls: Type["AgentMemoryFragment"],
        observation: Union[str, StructuredObservation, List[StructuredObservation]],
        embeddings: Optional[List[float]] = None,
        memory_id: Optional[int] = None,
        importance: Optional[float] = None,
        is_insight: bool = False,
        last_accessed_time: Optional[datetime] = None,
        **kwargs,
    ) -> "AgentMemoryFragment":
        """Build a memory fragment from the given parameters.

        Args:
            observation: Either the JSON text of a structured observation or
                the already-decoded structure (a mapping, or a list of
                mappings as produced by ``reduce``).
            embeddings: Optional embedding vector.
            memory_id: Optional memory id.
            importance: Optional importance score.
            is_insight: Whether the fragment is an insight.
            last_accessed_time: Optional access time.
            **kwargs: Accepted for interface compatibility and ignored.

        Returns:
            AgentMemoryFragment: A new instance of ``cls``.

        Raises:
            json.JSONDecodeError: If ``observation`` is a string that is not
                valid JSON.
        """
        if isinstance(observation, str):
            # Text input is the serialised form; decode it first.
            observation = json.loads(observation)
        return cls(
            observation=observation,
            embeddings=embeddings,
            memory_id=memory_id,
            importance=importance,
            last_accessed_time=last_accessed_time,
            is_insight=is_insight,
        )

    def reduce(
        self, memory_fragments: List["StructuredAgentMemoryFragment"], **kwargs
    ) -> "StructuredAgentMemoryFragment":
        """Reduce memory fragments to a single memory fragment.

        Each fragment's raw observation is decoded from JSON and the results
        are collected into one flat list. A fragment whose payload is already
        a list (for example the output of an earlier ``reduce``) contributes
        its elements individually, so the merged observation never contains
        nested lists. Fragments that cannot be decoded are logged and skipped.

        Args:
            memory_fragments(List[T]): Memory fragments
            **kwargs: Forwarded to ``build_from``.

        Returns:
            T: The reduced memory fragment

        Raises:
            ValueError: If ``memory_fragments`` is empty.
        """
        if len(memory_fragments) == 0:
            raise ValueError("Memory fragments is empty.")
        if len(memory_fragments) == 1:
            return memory_fragments[0]
        obs: list = []
        for memory_fragment in memory_fragments:
            try:
                parsed = json.loads(memory_fragment.raw_observation)
            except Exception as e:
                # Best effort: one unreadable fragment must not abort the merge.
                logger.warning(
                    "Failed to parse observation %s: %s",
                    memory_fragment.raw_observation,
                    e,
                )
                continue
            if isinstance(parsed, list):
                # Flatten list payloads instead of nesting them.
                obs.extend(parsed)
            else:
                obs.append(parsed)
        return self.current_class.build_from(obs, **kwargs)  # type: ignore
class AgentMemory(Memory[AgentMemoryFragment]): class AgentMemory(Memory[AgentMemoryFragment]):
"""Agent memory.""" """Agent memory."""

View File

@ -2,7 +2,7 @@
from abc import ABC from abc import ABC
from enum import Enum from enum import Enum
from typing import Dict, List, Optional from typing import TYPE_CHECKING, Dict, List, Optional, Type, Union
from jinja2 import Environment, Template, meta from jinja2 import Environment, Template, meta
from jinja2.sandbox import SandboxedEnvironment from jinja2.sandbox import SandboxedEnvironment
@ -10,10 +10,17 @@ from jinja2.sandbox import SandboxedEnvironment
from dbgpt._private.pydantic import BaseModel, ConfigDict, Field from dbgpt._private.pydantic import BaseModel, ConfigDict, Field
from .action.base import ActionOutput from .action.base import ActionOutput
from .memory.agent_memory import AgentMemory, AgentMemoryFragment from .memory.agent_memory import (
AgentMemory,
AgentMemoryFragment,
StructuredAgentMemoryFragment,
)
from .memory.llm import LLMImportanceScorer, LLMInsightExtractor from .memory.llm import LLMImportanceScorer, LLMInsightExtractor
from .profile import Profile, ProfileConfig from .profile import Profile, ProfileConfig
if TYPE_CHECKING:
from .agent import AgentMessage
class AgentRunMode(str, Enum): class AgentRunMode(str, Enum):
"""Agent run mode.""" """Agent run mode."""
@ -210,10 +217,15 @@ class Role(ABC, BaseModel):
""" """
return None return None
@property
def memory_fragment_class(self) -> Type[AgentMemoryFragment]:
    """Return the memory fragment class.

    Returns:
        Type[AgentMemoryFragment]: The fragment class; here always
            ``AgentMemoryFragment``.
    """
    fragment_cls: Type[AgentMemoryFragment] = AgentMemoryFragment
    return fragment_cls
async def read_memories( async def read_memories(
self, self,
question: str, question: str,
) -> str: ) -> Union[str, List["AgentMessage"]]:
"""Read the memories from the memory.""" """Read the memories from the memory."""
memories = await self.memory.read(question) memories = await self.memory.read(question)
recent_messages = [m.raw_observation for m in memories] recent_messages = [m.raw_observation for m in memories]
@ -239,6 +251,7 @@ class Role(ABC, BaseModel):
action_output(ActionOutput): The action output. action_output(ActionOutput): The action output.
check_pass(bool): Whether the check pass. check_pass(bool): Whether the check pass.
check_fail_reason(str): The check fail reason. check_fail_reason(str): The check fail reason.
current_retry_counter(int): The current retry counter.
Returns: Returns:
AgentMemoryFragment: The memory fragment created. AgentMemoryFragment: The memory fragment created.
@ -247,17 +260,29 @@ class Role(ABC, BaseModel):
raise ValueError("Action output is required to save to memory.") raise ValueError("Action output is required to save to memory.")
mem_thoughts = action_output.thoughts or ai_message mem_thoughts = action_output.thoughts or ai_message
observation = action_output.observations action = action_output.action
action_input = action_output.action_input
observation = check_fail_reason or action_output.observations
memory_map = { memory_map = {
"question": question,
"thought": mem_thoughts, "thought": mem_thoughts,
"action": check_fail_reason, "action": action,
"observation": observation, "observation": observation,
} }
if action_input:
memory_map["action_input"] = action_input
if current_retry_counter is not None and current_retry_counter == 0:
memory_map["question"] = question
write_memory_template = self.write_memory_template write_memory_template = self.write_memory_template
memory_content = self._render_template(write_memory_template, **memory_map) memory_content = self._render_template(write_memory_template, **memory_map)
fragment = AgentMemoryFragment(memory_content)
fragment_cls: Type[AgentMemoryFragment] = self.memory_fragment_class
if issubclass(fragment_cls, StructuredAgentMemoryFragment):
fragment = fragment_cls(memory_map)
else:
fragment = fragment_cls(memory_content)
await self.memory.write(fragment) await self.memory.write(fragment)
action_output.memory_fragments = { action_output.memory_fragments = {
@ -270,9 +295,10 @@ class Role(ABC, BaseModel):
async def recovering_memory(self, action_outputs: List[ActionOutput]) -> None: async def recovering_memory(self, action_outputs: List[ActionOutput]) -> None:
"""Recover the memory from the action outputs.""" """Recover the memory from the action outputs."""
fragments = [] fragments = []
fragment_cls: Type[AgentMemoryFragment] = self.memory_fragment_class
for action_output in action_outputs: for action_output in action_outputs:
if action_output.memory_fragments: if action_output.memory_fragments:
fragment = AgentMemoryFragment.build_from( fragment = fragment_cls.build_from(
observation=action_output.memory_fragments["memory"], observation=action_output.memory_fragments["memory"],
importance=action_output.memory_fragments.get("importance"), importance=action_output.memory_fragments.get("importance"),
memory_id=action_output.memory_fragments.get("id"), memory_id=action_output.memory_fragments.get("id"),

View File

@ -1,5 +1,6 @@
import json
import logging import logging
from typing import Any, Dict, List, Optional from typing import Any, Dict, List, Optional, Type, Union
from dbgpt._private.pydantic import Field from dbgpt._private.pydantic import Field
from dbgpt.agent import ( from dbgpt.agent import (
@ -11,12 +12,14 @@ from dbgpt.agent import (
ProfileConfig, ProfileConfig,
Resource, Resource,
ResourceType, ResourceType,
StructuredAgentMemoryFragment,
) )
from dbgpt.agent.core.role import AgentRunMode from dbgpt.agent.core.role import AgentRunMode
from dbgpt.agent.resource import BaseTool, ResourcePack, ToolPack from dbgpt.agent.resource import BaseTool, ResourcePack, ToolPack
from dbgpt.agent.util.react_parser import ReActOutputParser from dbgpt.agent.util.react_parser import ReActOutputParser
from dbgpt.util.configure import DynConfig from dbgpt.util.configure import DynConfig
from ...core import ModelMessageRoleType
from .actions.react_action import ReActAction, Terminate from .actions.react_action import ReActAction, Terminate
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -63,12 +66,9 @@ Please Solve this task:
Please answer in the same language as the user's question. Please answer in the same language as the user's question.
The current time is: {{ now_time }}. The current time is: {{ now_time }}.
""" """
_REACT_USER_TEMPLATE = """\
{% if most_recent_memories %}\ # No additional user prompt template is needed
Most recent message: _REACT_USER_TEMPLATE = """"""
{{ most_recent_memories }}
{% endif %}\
"""
_REACT_WRITE_MEMORY_TEMPLATE = """\ _REACT_WRITE_MEMORY_TEMPLATE = """\
@ -225,7 +225,10 @@ class ReActAgent(ConversableAgent):
steps = self.parser.parse(message_content) steps = self.parser.parse(message_content)
err_msg = None err_msg = None
if not steps: if not steps:
err_msg = "No correct response found." err_msg = (
"No correct response found. Please check your response, which must"
" be in the format indicated in the system prompt."
)
elif len(steps) != 1: elif len(steps) != 1:
err_msg = "Only one action is allowed each time." err_msg = "Only one action is allowed each time."
if err_msg: if err_msg:
@ -243,56 +246,72 @@ class ReActAgent(ConversableAgent):
) )
return action_output return action_output
async def write_memories( @property
def memory_fragment_class(self) -> Type[AgentMemoryFragment]:
"""Return the memory fragment class."""
return StructuredAgentMemoryFragment
async def read_memories(
self, self,
question: str, observation: str,
ai_message: str, ) -> Union[str, List["AgentMessage"]]:
action_output: Optional[ActionOutput] = None, memories = await self.memory.read(observation)
check_pass: bool = True, not_json_memories = []
check_fail_reason: Optional[str] = None, messages = []
current_retry_counter: Optional[int] = None, structured_memories = []
) -> AgentMemoryFragment: for m in memories:
"""Write the memories to the memory. if m.raw_observation:
try:
mem_dict = json.loads(m.raw_observation)
if isinstance(mem_dict, dict):
structured_memories.append(mem_dict)
elif isinstance(mem_dict, list):
structured_memories.extend(mem_dict)
else:
raise ValueError("Invalid memory format.")
except Exception:
not_json_memories.append(m.raw_observation)
We suggest you to override this method to save the conversation to memory for mem_dict in structured_memories:
according to your needs. question = mem_dict.get("question")
thought = mem_dict.get("thought")
Args: action = mem_dict.get("action")
question(str): The question received. action_input = mem_dict.get("action_input")
ai_message(str): The AI message, LLM output. observation = mem_dict.get("observation")
action_output(ActionOutput): The action output. if question:
check_pass(bool): Whether the check pass. messages.append(
check_fail_reason(str): The check fail reason. AgentMessage(
content=f"Question: {question}",
Returns: role=ModelMessageRoleType.HUMAN,
AgentMemoryFragment: The memory fragment created. )
""" )
if not action_output: ai_content = []
raise ValueError("Action output is required to save to memory.") if thought:
ai_content.append(f"Thought: {thought}")
mem_thoughts = action_output.thoughts or ai_message if action:
action = action_output.action ai_content.append(f"Action: {action}")
action_input = action_output.action_input
observation = check_fail_reason or action_output.observations
memory_map = {
"thought": mem_thoughts,
"action": action,
"observation": observation,
}
if action_input: if action_input:
memory_map["action_input"] = action_input ai_content.append(f"Action Input: {action_input}")
messages.append(
AgentMessage(
content="\n".join(ai_content),
role=ModelMessageRoleType.AI,
)
)
if current_retry_counter is not None and current_retry_counter == 0: if observation:
memory_map["question"] = question messages.append(
AgentMessage(
content=f"Observation: {observation}",
role=ModelMessageRoleType.HUMAN,
)
)
write_memory_template = self.write_memory_template if not messages and not_json_memories:
memory_content = self._render_template(write_memory_template, **memory_map) messages.append(
fragment = AgentMemoryFragment(memory_content) AgentMessage(
await self.memory.write(fragment) content="\n".join(not_json_memories),
action_output.memory_fragments = { role=ModelMessageRoleType.HUMAN,
"memory": fragment.raw_observation, )
"id": fragment.id, )
"importance": fragment.importance, return messages
}
return fragment

View File

@ -67,10 +67,12 @@ class StartAppAssistantAgent(ConversableAgent):
self, self,
received_message: AgentMessage, received_message: AgentMessage,
sender: Agent, sender: Agent,
observation: Optional[str] = None,
rely_messages: Optional[List[AgentMessage]] = None, rely_messages: Optional[List[AgentMessage]] = None,
historical_dialogues: Optional[List[AgentMessage]] = None, historical_dialogues: Optional[List[AgentMessage]] = None,
context: Optional[Dict[str, Any]] = None, context: Optional[Dict[str, Any]] = None,
is_retry_chat: bool = False, is_retry_chat: bool = False,
current_retry_counter: Optional[int] = None,
) -> Tuple[List[AgentMessage], Optional[Dict]]: ) -> Tuple[List[AgentMessage], Optional[Dict]]:
if rely_messages and len(rely_messages) > 0: if rely_messages and len(rely_messages) > 0:
return rely_messages[-1:], None return rely_messages[-1:], None

16308
uv.lock

File diff suppressed because one or more lines are too long