feat(agent): Optimize agent memory (#2665)

This commit is contained in:
Fangyin Cheng 2025-04-30 09:49:17 +08:00 committed by GitHub
parent b901cbc9a6
commit 3a00aca113
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
8 changed files with 10637 additions and 6023 deletions

View File

@ -5,6 +5,7 @@ from __future__ import annotations
import asyncio
import json
import logging
import time
from concurrent.futures import Executor, ThreadPoolExecutor
from datetime import datetime
from typing import Any, Callable, Dict, List, Optional, Tuple, Type, final
@ -45,6 +46,7 @@ class ConversableAgent(Role, Agent):
bind_prompt: Optional[PromptTemplate] = None
run_mode: Optional[AgentRunMode] = Field(default=None, description="Run mode")
max_retry_count: int = 3
max_timeout: int = 600
llm_client: Optional[AIWrapper] = None
# 确认当前Agent是否需要进行流式输出
stream_out: bool = True
@ -363,6 +365,7 @@ class ConversableAgent(Role, Agent):
fail_reason = None
current_retry_counter = 0
start_time = time.time()
is_success = True
observation = received_message.content or ""
while current_retry_counter < self.max_retry_count:
@ -402,10 +405,12 @@ class ConversableAgent(Role, Agent):
thinking_messages, resource_info = await self._load_thinking_messages(
received_message=received_message,
sender=sender,
observation=observation,
rely_messages=rely_messages,
historical_dialogues=historical_dialogues,
context=reply_message.get_dict_context(),
is_retry_chat=is_retry_chat,
current_retry_counter=current_retry_counter,
)
with root_tracer.start_span(
"agent.generate_reply.thinking",
@ -493,6 +498,7 @@ class ConversableAgent(Role, Agent):
logger.warning("No retry available!")
break
fail_reason = reason
observation = fail_reason
await self.write_memories(
question=question,
ai_message=ai_message,
@ -514,6 +520,13 @@ class ConversableAgent(Role, Agent):
if self.run_mode != AgentRunMode.LOOP or act_out.terminate:
logger.debug(f"Agent {self.name} reply success!{reply_message}")
break
time_cost = time.time() - start_time
if time_cost > self.max_timeout:
logger.warning(
f"Agent {self.name} run time out!{time_cost} > "
f"{self.max_timeout}"
)
break
# Continue to run the next round
current_retry_counter += 1
@ -1072,15 +1085,25 @@ class ConversableAgent(Role, Agent):
self,
received_message: AgentMessage,
sender: Agent,
observation: Optional[str] = None,
rely_messages: Optional[List[AgentMessage]] = None,
historical_dialogues: Optional[List[AgentMessage]] = None,
context: Optional[Dict[str, Any]] = None,
is_retry_chat: bool = False,
current_retry_counter: Optional[int] = None,
) -> Tuple[List[AgentMessage], Optional[Dict]]:
observation = received_message.content
if not observation:
question = received_message.content
observation = observation or question
if not question:
raise ValueError("The received message content is empty!")
most_recent_memories = ""
memory_list = []
# Read the memories according to the current observation
memories = await self.read_memories(observation)
if isinstance(memories, list):
memory_list = memories
else:
most_recent_memories = memories
has_memories = True if memories else False
reply_message_str = ""
if context is None:
@ -1102,8 +1125,9 @@ class ConversableAgent(Role, Agent):
elif message.role == ModelMessageRoleType.AI:
reply_message_str += f"Observation: {message.content}\n"
if reply_message_str:
memories += "\n" + reply_message_str
most_recent_memories += "\n" + reply_message_str
try:
# Load the resource prompt according to the current observation
resource_prompt_str, resource_references = await self.load_resource(
observation, is_retry_chat=is_retry_chat
)
@ -1114,21 +1138,19 @@ class ConversableAgent(Role, Agent):
resource_vars = await self.generate_resource_variables(resource_prompt_str)
system_prompt = await self.build_system_prompt(
question=observation,
most_recent_memories=memories,
question=question,
most_recent_memories=most_recent_memories,
resource_vars=resource_vars,
context=context,
is_retry_chat=is_retry_chat,
)
user_prompt = await self.build_prompt(
question=observation,
question=question,
is_system=False,
most_recent_memories=memories,
most_recent_memories=most_recent_memories,
resource_vars=resource_vars,
**context,
)
if not user_prompt:
user_prompt = f"Observation: {observation}"
agent_messages = []
if system_prompt:
@ -1153,14 +1175,21 @@ class ConversableAgent(Role, Agent):
message.role = ModelMessageRoleType.AI
agent_messages.append(message)
# Current user input information
agent_messages.append(
AgentMessage(
content=user_prompt,
role=ModelMessageRoleType.HUMAN,
)
)
if memory_list:
agent_messages.extend(memory_list)
# Current user input information
if not user_prompt and (not memory_list or not current_retry_counter):
# The user prompt is empty, and the current retry count is 0 or the memory
# is empty
user_prompt = f"Observation: {observation}"
if user_prompt:
agent_messages.append(
AgentMessage(
content=user_prompt,
role=ModelMessageRoleType.HUMAN,
)
)
return agent_messages, resource_references

View File

@ -160,10 +160,12 @@ class ManagerAgent(ConversableAgent, Team):
self,
received_message: AgentMessage,
sender: Agent,
observation: Optional[str] = None,
rely_messages: Optional[List[AgentMessage]] = None,
historical_dialogues: Optional[List[AgentMessage]] = None,
context: Optional[Dict[str, Any]] = None,
is_retry_chat: bool = False,
current_retry_counter: Optional[int] = None,
) -> Tuple[List[AgentMessage], Optional[Dict]]:
"""Load messages for thinking."""
return [AgentMessage(content=received_message.content)], None

View File

@ -1,6 +1,10 @@
"""Memory module for the agent."""
from .agent_memory import AgentMemory, AgentMemoryFragment # noqa: F401
from .agent_memory import ( # noqa: F401
AgentMemory,
AgentMemoryFragment,
StructuredAgentMemoryFragment,
)
from .base import ( # noqa: F401
ImportanceScorer,
InsightExtractor,

View File

@ -1,7 +1,11 @@
"""Agent memory module."""
import json
import logging
from datetime import datetime
from typing import Callable, List, Optional, Type, cast
from typing import Callable, List, Optional, Type, Union, cast
from typing_extensions import TypedDict
from dbgpt.core import LLMClient
from dbgpt.util.annotations import immutable, mutable
@ -18,6 +22,18 @@ from .base import (
)
from .gpts import GptsMemory, GptsMessageMemory, GptsPlansMemory
logger = logging.getLogger(__name__)
class StructuredObservation(TypedDict):
"""Structured observation for agent memory."""
question: Optional[str]
thought: Optional[str]
action: Optional[str]
action_input: Optional[str]
observation: Optional[str]
class AgentMemoryFragment(MemoryFragment):
"""Default memory fragment for agent memory."""
@ -168,6 +184,94 @@ class AgentMemoryFragment(MemoryFragment):
)
class StructuredAgentMemoryFragment(AgentMemoryFragment):
"""Structured memory fragment for agent memory."""
def __init__(
self,
observation: Union[StructuredObservation, List[StructuredObservation]],
embeddings: Optional[List[float]] = None,
memory_id: Optional[int] = None,
importance: Optional[float] = None,
last_accessed_time: Optional[datetime] = None,
is_insight: bool = False,
):
"""Create a structured memory fragment."""
super().__init__(
observation=self.to_raw_observation(observation),
embeddings=embeddings,
memory_id=memory_id,
importance=importance,
last_accessed_time=last_accessed_time,
is_insight=is_insight,
)
self._structured_observation = observation
def to_raw_observation(
self, observation: Union[StructuredObservation, List[StructuredObservation]]
) -> str:
"""Convert the structured observation to a raw observation.
Args:
observation(StructuredObservation): Structured observation
Returns:
str: Raw observation
"""
return json.dumps(observation, ensure_ascii=False)
@classmethod
def build_from(
cls: Type["AgentMemoryFragment"],
observation: Union[str, StructuredObservation],
embeddings: Optional[List[float]] = None,
memory_id: Optional[int] = None,
importance: Optional[float] = None,
is_insight: bool = False,
last_accessed_time: Optional[datetime] = None,
**kwargs,
) -> "AgentMemoryFragment":
"""Build a memory fragment from the given parameters."""
if isinstance(observation, str):
observation = json.loads(observation)
return cls(
observation=observation,
embeddings=embeddings,
memory_id=memory_id,
importance=importance,
last_accessed_time=last_accessed_time,
is_insight=is_insight,
)
def reduce(
self, memory_fragments: List["StructuredAgentMemoryFragment"], **kwargs
) -> "StructuredAgentMemoryFragment":
"""Reduce memory fragments to a single memory fragment.
Args:
memory_fragments(List[T]): Memory fragments
Returns:
T: The reduced memory fragment
"""
if len(memory_fragments) == 0:
raise ValueError("Memory fragments is empty.")
if len(memory_fragments) == 1:
return memory_fragments[0]
obs = []
for memory_fragment in memory_fragments:
try:
obs.append(json.loads(memory_fragment.raw_observation))
except Exception as e:
logger.warning(
"Failed to parse observation %s: %s",
memory_fragment.raw_observation,
e,
)
return self.current_class.build_from(obs, **kwargs) # type: ignore
class AgentMemory(Memory[AgentMemoryFragment]):
"""Agent memory."""

View File

@ -2,7 +2,7 @@
from abc import ABC
from enum import Enum
from typing import Dict, List, Optional
from typing import TYPE_CHECKING, Dict, List, Optional, Type, Union
from jinja2 import Environment, Template, meta
from jinja2.sandbox import SandboxedEnvironment
@ -10,10 +10,17 @@ from jinja2.sandbox import SandboxedEnvironment
from dbgpt._private.pydantic import BaseModel, ConfigDict, Field
from .action.base import ActionOutput
from .memory.agent_memory import AgentMemory, AgentMemoryFragment
from .memory.agent_memory import (
AgentMemory,
AgentMemoryFragment,
StructuredAgentMemoryFragment,
)
from .memory.llm import LLMImportanceScorer, LLMInsightExtractor
from .profile import Profile, ProfileConfig
if TYPE_CHECKING:
from .agent import AgentMessage
class AgentRunMode(str, Enum):
"""Agent run mode."""
@ -210,10 +217,15 @@ class Role(ABC, BaseModel):
"""
return None
@property
def memory_fragment_class(self) -> Type[AgentMemoryFragment]:
"""Return the memory fragment class."""
return AgentMemoryFragment
async def read_memories(
self,
question: str,
) -> str:
) -> Union[str, List["AgentMessage"]]:
"""Read the memories from the memory."""
memories = await self.memory.read(question)
recent_messages = [m.raw_observation for m in memories]
@ -239,6 +251,7 @@ class Role(ABC, BaseModel):
action_output(ActionOutput): The action output.
check_pass(bool): Whether the check pass.
check_fail_reason(str): The check fail reason.
current_retry_counter(int): The current retry counter.
Returns:
AgentMemoryFragment: The memory fragment created.
@ -247,17 +260,29 @@ class Role(ABC, BaseModel):
raise ValueError("Action output is required to save to memory.")
mem_thoughts = action_output.thoughts or ai_message
observation = action_output.observations
action = action_output.action
action_input = action_output.action_input
observation = check_fail_reason or action_output.observations
memory_map = {
"question": question,
"thought": mem_thoughts,
"action": check_fail_reason,
"action": action,
"observation": observation,
}
if action_input:
memory_map["action_input"] = action_input
if current_retry_counter is not None and current_retry_counter == 0:
memory_map["question"] = question
write_memory_template = self.write_memory_template
memory_content = self._render_template(write_memory_template, **memory_map)
fragment = AgentMemoryFragment(memory_content)
fragment_cls: Type[AgentMemoryFragment] = self.memory_fragment_class
if issubclass(fragment_cls, StructuredAgentMemoryFragment):
fragment = fragment_cls(memory_map)
else:
fragment = fragment_cls(memory_content)
await self.memory.write(fragment)
action_output.memory_fragments = {
@ -270,9 +295,10 @@ class Role(ABC, BaseModel):
async def recovering_memory(self, action_outputs: List[ActionOutput]) -> None:
"""Recover the memory from the action outputs."""
fragments = []
fragment_cls: Type[AgentMemoryFragment] = self.memory_fragment_class
for action_output in action_outputs:
if action_output.memory_fragments:
fragment = AgentMemoryFragment.build_from(
fragment = fragment_cls.build_from(
observation=action_output.memory_fragments["memory"],
importance=action_output.memory_fragments.get("importance"),
memory_id=action_output.memory_fragments.get("id"),

View File

@ -1,5 +1,6 @@
import json
import logging
from typing import Any, Dict, List, Optional
from typing import Any, Dict, List, Optional, Type, Union
from dbgpt._private.pydantic import Field
from dbgpt.agent import (
@ -11,12 +12,14 @@ from dbgpt.agent import (
ProfileConfig,
Resource,
ResourceType,
StructuredAgentMemoryFragment,
)
from dbgpt.agent.core.role import AgentRunMode
from dbgpt.agent.resource import BaseTool, ResourcePack, ToolPack
from dbgpt.agent.util.react_parser import ReActOutputParser
from dbgpt.util.configure import DynConfig
from ...core import ModelMessageRoleType
from .actions.react_action import ReActAction, Terminate
logger = logging.getLogger(__name__)
@ -63,12 +66,9 @@ Please Solve this task:
Please answer in the same language as the user's question.
The current time is: {{ now_time }}.
"""
_REACT_USER_TEMPLATE = """\
{% if most_recent_memories %}\
Most recent message:
{{ most_recent_memories }}
{% endif %}\
"""
# Not needed additional user prompt template
_REACT_USER_TEMPLATE = """"""
_REACT_WRITE_MEMORY_TEMPLATE = """\
@ -225,7 +225,10 @@ class ReActAgent(ConversableAgent):
steps = self.parser.parse(message_content)
err_msg = None
if not steps:
err_msg = "No correct response found."
err_msg = (
"No correct response found. Please check your response, which must"
" be in the format indicated in the system prompt."
)
elif len(steps) != 1:
err_msg = "Only one action is allowed each time."
if err_msg:
@ -243,56 +246,72 @@ class ReActAgent(ConversableAgent):
)
return action_output
async def write_memories(
@property
def memory_fragment_class(self) -> Type[AgentMemoryFragment]:
"""Return the memory fragment class."""
return StructuredAgentMemoryFragment
async def read_memories(
self,
question: str,
ai_message: str,
action_output: Optional[ActionOutput] = None,
check_pass: bool = True,
check_fail_reason: Optional[str] = None,
current_retry_counter: Optional[int] = None,
) -> AgentMemoryFragment:
"""Write the memories to the memory.
observation: str,
) -> Union[str, List["AgentMessage"]]:
memories = await self.memory.read(observation)
not_json_memories = []
messages = []
structured_memories = []
for m in memories:
if m.raw_observation:
try:
mem_dict = json.loads(m.raw_observation)
if isinstance(mem_dict, dict):
structured_memories.append(mem_dict)
elif isinstance(mem_dict, list):
structured_memories.extend(mem_dict)
else:
raise ValueError("Invalid memory format.")
except Exception:
not_json_memories.append(m.raw_observation)
We suggest you to override this method to save the conversation to memory
according to your needs.
for mem_dict in structured_memories:
question = mem_dict.get("question")
thought = mem_dict.get("thought")
action = mem_dict.get("action")
action_input = mem_dict.get("action_input")
observation = mem_dict.get("observation")
if question:
messages.append(
AgentMessage(
content=f"Question: {question}",
role=ModelMessageRoleType.HUMAN,
)
)
ai_content = []
if thought:
ai_content.append(f"Thought: {thought}")
if action:
ai_content.append(f"Action: {action}")
if action_input:
ai_content.append(f"Action Input: {action_input}")
messages.append(
AgentMessage(
content="\n".join(ai_content),
role=ModelMessageRoleType.AI,
)
)
Args:
question(str): The question received.
ai_message(str): The AI message, LLM output.
action_output(ActionOutput): The action output.
check_pass(bool): Whether the check pass.
check_fail_reason(str): The check fail reason.
if observation:
messages.append(
AgentMessage(
content=f"Observation: {observation}",
role=ModelMessageRoleType.HUMAN,
)
)
Returns:
AgentMemoryFragment: The memory fragment created.
"""
if not action_output:
raise ValueError("Action output is required to save to memory.")
mem_thoughts = action_output.thoughts or ai_message
action = action_output.action
action_input = action_output.action_input
observation = check_fail_reason or action_output.observations
memory_map = {
"thought": mem_thoughts,
"action": action,
"observation": observation,
}
if action_input:
memory_map["action_input"] = action_input
if current_retry_counter is not None and current_retry_counter == 0:
memory_map["question"] = question
write_memory_template = self.write_memory_template
memory_content = self._render_template(write_memory_template, **memory_map)
fragment = AgentMemoryFragment(memory_content)
await self.memory.write(fragment)
action_output.memory_fragments = {
"memory": fragment.raw_observation,
"id": fragment.id,
"importance": fragment.importance,
}
return fragment
if not messages and not_json_memories:
messages.append(
AgentMessage(
content="\n".join(not_json_memories),
role=ModelMessageRoleType.HUMAN,
)
)
return messages

View File

@ -67,10 +67,12 @@ class StartAppAssistantAgent(ConversableAgent):
self,
received_message: AgentMessage,
sender: Agent,
observation: Optional[str] = None,
rely_messages: Optional[List[AgentMessage]] = None,
historical_dialogues: Optional[List[AgentMessage]] = None,
context: Optional[Dict[str, Any]] = None,
is_retry_chat: bool = False,
current_retry_counter: Optional[int] = None,
) -> Tuple[List[AgentMessage], Optional[Dict]]:
if rely_messages and len(rely_messages) > 0:
return rely_messages[-1:], None

16308
uv.lock

File diff suppressed because one or more lines are too long