feat(agent): Optimize agent memory (#2665)
parent b901cbc9a6
commit 3a00aca113
@@ -5,6 +5,7 @@ from __future__ import annotations
 import asyncio
 import json
 import logging
+import time
 from concurrent.futures import Executor, ThreadPoolExecutor
 from datetime import datetime
 from typing import Any, Callable, Dict, List, Optional, Tuple, Type, final
@@ -45,6 +46,7 @@ class ConversableAgent(Role, Agent):
     bind_prompt: Optional[PromptTemplate] = None
     run_mode: Optional[AgentRunMode] = Field(default=None, description="Run mode")
     max_retry_count: int = 3
+    max_timeout: int = 600
     llm_client: Optional[AIWrapper] = None
     # Whether the current agent needs to stream its output
     stream_out: bool = True
@@ -363,6 +365,7 @@ class ConversableAgent(Role, Agent):

         fail_reason = None
         current_retry_counter = 0
+        start_time = time.time()
         is_success = True
         observation = received_message.content or ""
         while current_retry_counter < self.max_retry_count:
@@ -402,10 +405,12 @@ class ConversableAgent(Role, Agent):
                 thinking_messages, resource_info = await self._load_thinking_messages(
                     received_message=received_message,
                     sender=sender,
+                    observation=observation,
                     rely_messages=rely_messages,
                     historical_dialogues=historical_dialogues,
                     context=reply_message.get_dict_context(),
                     is_retry_chat=is_retry_chat,
+                    current_retry_counter=current_retry_counter,
                 )
                 with root_tracer.start_span(
                     "agent.generate_reply.thinking",
@@ -493,6 +498,7 @@ class ConversableAgent(Role, Agent):
                     logger.warning("No retry available!")
                     break
                 fail_reason = reason
+                observation = fail_reason
                 await self.write_memories(
                     question=question,
                     ai_message=ai_message,
@@ -514,6 +520,13 @@ class ConversableAgent(Role, Agent):
             if self.run_mode != AgentRunMode.LOOP or act_out.terminate:
                 logger.debug(f"Agent {self.name} reply success!{reply_message}")
                 break
+            time_cost = time.time() - start_time
+            if time_cost > self.max_timeout:
+                logger.warning(
+                    f"Agent {self.name} run time out!{time_cost} > "
+                    f"{self.max_timeout}"
+                )
+                break

             # Continue to run the next round
             current_retry_counter += 1
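
The retry loop now has a wall-clock budget in addition to max_retry_count: once time.time() - start_time exceeds max_timeout (600 seconds by default), the loop exits even if retries remain. A minimal standalone sketch of the pattern, with illustrative names that are not part of DB-GPT:

    import time

    MAX_RETRY_COUNT = 3
    MAX_TIMEOUT = 600  # seconds, mirrors the new max_timeout field

    def run_with_budget(step) -> None:
        # Retry until the step succeeds, retries run out, or the time
        # budget is exhausted, whichever comes first.
        start_time = time.time()
        retries = 0
        while retries < MAX_RETRY_COUNT:
            if step(retries):
                break
            if time.time() - start_time > MAX_TIMEOUT:
                break
            retries += 1
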
@@ -1072,15 +1085,25 @@ class ConversableAgent(Role, Agent):
         self,
         received_message: AgentMessage,
         sender: Agent,
+        observation: Optional[str] = None,
         rely_messages: Optional[List[AgentMessage]] = None,
         historical_dialogues: Optional[List[AgentMessage]] = None,
         context: Optional[Dict[str, Any]] = None,
         is_retry_chat: bool = False,
+        current_retry_counter: Optional[int] = None,
     ) -> Tuple[List[AgentMessage], Optional[Dict]]:
-        observation = received_message.content
-        if not observation:
+        question = received_message.content
+        observation = observation or question
+        if not question:
             raise ValueError("The received message content is empty!")
+        most_recent_memories = ""
+        memory_list = []
+        # Read the memories according to the current observation
         memories = await self.read_memories(observation)
+        if isinstance(memories, list):
+            memory_list = memories
+        else:
+            most_recent_memories = memories
         has_memories = True if memories else False
         reply_message_str = ""
         if context is None:
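
read_memories can now return either a rendered string or a list of AgentMessage objects, so _load_thinking_messages splits the result into two variables: most_recent_memories feeds the prompt template, while memory_list is later appended to the conversation as real chat turns. A sketch of the dispatch, simplified with a plain object placeholder for AgentMessage:

    from typing import List, Tuple, Union

    def split_memories(
        memories: Union[str, List[object]],
    ) -> Tuple[str, List[object]]:
        # String memories are injected into the prompt template;
        # message-shaped memories are replayed as chat turns instead.
        if isinstance(memories, list):
            return "", memories
        return memories, []
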
@@ -1102,8 +1125,9 @@ class ConversableAgent(Role, Agent):
                 elif message.role == ModelMessageRoleType.AI:
                     reply_message_str += f"Observation: {message.content}\n"
         if reply_message_str:
-            memories += "\n" + reply_message_str
+            most_recent_memories += "\n" + reply_message_str
         try:
+            # Load the resource prompt according to the current observation
             resource_prompt_str, resource_references = await self.load_resource(
                 observation, is_retry_chat=is_retry_chat
             )
@@ -1114,21 +1138,19 @@ class ConversableAgent(Role, Agent):
         resource_vars = await self.generate_resource_variables(resource_prompt_str)

         system_prompt = await self.build_system_prompt(
-            question=observation,
-            most_recent_memories=memories,
+            question=question,
+            most_recent_memories=most_recent_memories,
             resource_vars=resource_vars,
             context=context,
             is_retry_chat=is_retry_chat,
         )
         user_prompt = await self.build_prompt(
-            question=observation,
+            question=question,
             is_system=False,
-            most_recent_memories=memories,
+            most_recent_memories=most_recent_memories,
             resource_vars=resource_vars,
             **context,
         )
-        if not user_prompt:
-            user_prompt = f"Observation: {observation}"

         agent_messages = []
         if system_prompt:
@@ -1153,14 +1175,21 @@ class ConversableAgent(Role, Agent):
                 message.role = ModelMessageRoleType.AI
                 agent_messages.append(message)

+        if memory_list:
+            agent_messages.extend(memory_list)

         # Current user input information
+        if not user_prompt and (not memory_list or not current_retry_counter):
+            # The user prompt is empty, and the current retry count is 0 or the memory
+            # is empty
+            user_prompt = f"Observation: {observation}"
+        if user_prompt:
             agent_messages.append(
                 AgentMessage(
                     content=user_prompt,
                     role=ModelMessageRoleType.HUMAN,
                 )
             )

         return agent_messages, resource_references

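The raw observation is only injected as the user prompt when the template produced nothing and this is either the first attempt (current_retry_counter is 0 or None) or there are no message-shaped memories to replay; otherwise the replayed memory_list already carries the context. The condition, isolated as a sketch:

    from typing import List, Optional

    def fallback_user_prompt(
        user_prompt: Optional[str],
        memory_list: List[object],
        current_retry_counter: Optional[int],
        observation: str,
    ) -> Optional[str]:
        # Inject the observation only when nothing else provides it.
        if not user_prompt and (not memory_list or not current_retry_counter):
            return f"Observation: {observation}"
        return user_prompt
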
@@ -160,10 +160,12 @@ class ManagerAgent(ConversableAgent, Team):
         self,
         received_message: AgentMessage,
         sender: Agent,
+        observation: Optional[str] = None,
         rely_messages: Optional[List[AgentMessage]] = None,
         historical_dialogues: Optional[List[AgentMessage]] = None,
         context: Optional[Dict[str, Any]] = None,
         is_retry_chat: bool = False,
+        current_retry_counter: Optional[int] = None,
     ) -> Tuple[List[AgentMessage], Optional[Dict]]:
         """Load messages for thinking."""
         return [AgentMessage(content=received_message.content)], None
@@ -1,6 +1,10 @@
 """Memory module for the agent."""

-from .agent_memory import AgentMemory, AgentMemoryFragment  # noqa: F401
+from .agent_memory import (  # noqa: F401
+    AgentMemory,
+    AgentMemoryFragment,
+    StructuredAgentMemoryFragment,
+)
 from .base import (  # noqa: F401
     ImportanceScorer,
     InsightExtractor,
@@ -1,7 +1,11 @@
 """Agent memory module."""

+import json
+import logging
 from datetime import datetime
-from typing import Callable, List, Optional, Type, cast
+from typing import Callable, List, Optional, Type, Union, cast

+from typing_extensions import TypedDict
+
 from dbgpt.core import LLMClient
 from dbgpt.util.annotations import immutable, mutable
@@ -18,6 +22,18 @@ from .base import (
 )
 from .gpts import GptsMemory, GptsMessageMemory, GptsPlansMemory

+logger = logging.getLogger(__name__)
+
+
+class StructuredObservation(TypedDict):
+    """Structured observation for agent memory."""
+
+    question: Optional[str]
+    thought: Optional[str]
+    action: Optional[str]
+    action_input: Optional[str]
+    observation: Optional[str]
+

 class AgentMemoryFragment(MemoryFragment):
     """Default memory fragment for agent memory."""
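
At runtime a TypedDict is a plain dict; StructuredObservation just documents the keys a single ReAct step is expected to carry. An illustrative value (the contents here are made up):

    obs = {
        "question": "What is the average order value?",
        "thought": "I should query the orders table.",
        "action": "run_sql",
        "action_input": "SELECT AVG(amount) FROM orders;",
        "observation": "42.5",
    }
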
@@ -168,6 +184,94 @@ class AgentMemoryFragment(MemoryFragment):
         )


+class StructuredAgentMemoryFragment(AgentMemoryFragment):
+    """Structured memory fragment for agent memory."""
+
+    def __init__(
+        self,
+        observation: Union[StructuredObservation, List[StructuredObservation]],
+        embeddings: Optional[List[float]] = None,
+        memory_id: Optional[int] = None,
+        importance: Optional[float] = None,
+        last_accessed_time: Optional[datetime] = None,
+        is_insight: bool = False,
+    ):
+        """Create a structured memory fragment."""
+        super().__init__(
+            observation=self.to_raw_observation(observation),
+            embeddings=embeddings,
+            memory_id=memory_id,
+            importance=importance,
+            last_accessed_time=last_accessed_time,
+            is_insight=is_insight,
+        )
+        self._structured_observation = observation
+
+    def to_raw_observation(
+        self, observation: Union[StructuredObservation, List[StructuredObservation]]
+    ) -> str:
+        """Convert the structured observation to a raw observation.
+
+        Args:
+            observation(StructuredObservation): Structured observation
+
+        Returns:
+            str: Raw observation
+        """
+        return json.dumps(observation, ensure_ascii=False)
+
+    @classmethod
+    def build_from(
+        cls: Type["AgentMemoryFragment"],
+        observation: Union[str, StructuredObservation],
+        embeddings: Optional[List[float]] = None,
+        memory_id: Optional[int] = None,
+        importance: Optional[float] = None,
+        is_insight: bool = False,
+        last_accessed_time: Optional[datetime] = None,
+        **kwargs,
+    ) -> "AgentMemoryFragment":
+        """Build a memory fragment from the given parameters."""
+        if isinstance(observation, str):
+            observation = json.loads(observation)
+        return cls(
+            observation=observation,
+            embeddings=embeddings,
+            memory_id=memory_id,
+            importance=importance,
+            last_accessed_time=last_accessed_time,
+            is_insight=is_insight,
+        )
+
+    def reduce(
+        self, memory_fragments: List["StructuredAgentMemoryFragment"], **kwargs
+    ) -> "StructuredAgentMemoryFragment":
+        """Reduce memory fragments to a single memory fragment.
+
+        Args:
+            memory_fragments(List[T]): Memory fragments
+
+        Returns:
+            T: The reduced memory fragment
+        """
+        if len(memory_fragments) == 0:
+            raise ValueError("Memory fragments is empty.")
+        if len(memory_fragments) == 1:
+            return memory_fragments[0]
+
+        obs = []
+        for memory_fragment in memory_fragments:
+            try:
+                obs.append(json.loads(memory_fragment.raw_observation))
+            except Exception as e:
+                logger.warning(
+                    "Failed to parse observation %s: %s",
+                    memory_fragment.raw_observation,
+                    e,
+                )
+        return self.current_class.build_from(obs, **kwargs)  # type: ignore
+
+
 class AgentMemory(Memory[AgentMemoryFragment]):
     """Agent memory."""

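The class stores structured observations as JSON text (to_raw_observation), parses them back in build_from, and reduce merges several fragments into one list-valued fragment. The JSON round-trip it relies on, reduced to plain json calls (sample data is made up):

    import json

    first = {"thought": "Look up the schema", "action": "list_tables"}
    second = {"thought": "Query the table", "action": "run_sql"}

    # to_raw_observation: structured dict -> JSON string for storage
    raw_first = json.dumps(first, ensure_ascii=False)
    raw_second = json.dumps(second, ensure_ascii=False)

    # reduce: parse each raw observation and collect into a single list
    merged = [json.loads(raw_first), json.loads(raw_second)]
    assert json.loads(json.dumps(merged, ensure_ascii=False))[1]["action"] == "run_sql"
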
@@ -2,7 +2,7 @@

 from abc import ABC
 from enum import Enum
-from typing import Dict, List, Optional
+from typing import TYPE_CHECKING, Dict, List, Optional, Type, Union

 from jinja2 import Environment, Template, meta
 from jinja2.sandbox import SandboxedEnvironment
@@ -10,10 +10,17 @@ from jinja2.sandbox import SandboxedEnvironment
 from dbgpt._private.pydantic import BaseModel, ConfigDict, Field

 from .action.base import ActionOutput
-from .memory.agent_memory import AgentMemory, AgentMemoryFragment
+from .memory.agent_memory import (
+    AgentMemory,
+    AgentMemoryFragment,
+    StructuredAgentMemoryFragment,
+)
 from .memory.llm import LLMImportanceScorer, LLMInsightExtractor
 from .profile import Profile, ProfileConfig

+if TYPE_CHECKING:
+    from .agent import AgentMessage
+

 class AgentRunMode(str, Enum):
     """Agent run mode."""
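
The AgentMessage import is guarded by TYPE_CHECKING because role.py and agent.py would otherwise import each other at runtime; annotations that mention it then have to be string forward references. The pattern in isolation (a sketch; the guarded import never executes at runtime):

    from typing import TYPE_CHECKING, List, Union

    if TYPE_CHECKING:
        # Evaluated only by static type checkers, never at runtime.
        from .agent import AgentMessage

    def read_memories(question: str) -> Union[str, List["AgentMessage"]]:
        ...
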
@@ -210,10 +217,15 @@ class Role(ABC, BaseModel):
         """
         return None

+    @property
+    def memory_fragment_class(self) -> Type[AgentMemoryFragment]:
+        """Return the memory fragment class."""
+        return AgentMemoryFragment
+
     async def read_memories(
         self,
         question: str,
-    ) -> str:
+    ) -> Union[str, List["AgentMessage"]]:
         """Read the memories from the memory."""
         memories = await self.memory.read(question)
         recent_messages = [m.raw_observation for m in memories]
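
memory_fragment_class is the new extension hook: a subclass returns a different fragment type and write_memories/recovering_memory pick it up automatically. A minimal sketch of an override, assuming these classes are importable from dbgpt.agent the way the ReAct changes below import them:

    from typing import Type

    from dbgpt.agent import (
        AgentMemoryFragment,
        ConversableAgent,
        StructuredAgentMemoryFragment,
    )

    class MyStructuredAgent(ConversableAgent):
        @property
        def memory_fragment_class(self) -> Type[AgentMemoryFragment]:
            # Opt this agent into structured (JSON) memory fragments.
            return StructuredAgentMemoryFragment
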
@@ -239,6 +251,7 @@ class Role(ABC, BaseModel):
             action_output(ActionOutput): The action output.
             check_pass(bool): Whether the check pass.
             check_fail_reason(str): The check fail reason.
+            current_retry_counter(int): The current retry counter.

         Returns:
             AgentMemoryFragment: The memory fragment created.
@@ -247,17 +260,29 @@ class Role(ABC, BaseModel):
             raise ValueError("Action output is required to save to memory.")

         mem_thoughts = action_output.thoughts or ai_message
-        observation = action_output.observations
+        action = action_output.action
+        action_input = action_output.action_input
+        observation = check_fail_reason or action_output.observations

         memory_map = {
-            "question": question,
             "thought": mem_thoughts,
-            "action": check_fail_reason,
+            "action": action,
             "observation": observation,
         }
+        if action_input:
+            memory_map["action_input"] = action_input
+
+        if current_retry_counter is not None and current_retry_counter == 0:
+            memory_map["question"] = question
+
         write_memory_template = self.write_memory_template
         memory_content = self._render_template(write_memory_template, **memory_map)
-        fragment = AgentMemoryFragment(memory_content)
+        fragment_cls: Type[AgentMemoryFragment] = self.memory_fragment_class
+        if issubclass(fragment_cls, StructuredAgentMemoryFragment):
+            fragment = fragment_cls(memory_map)
+        else:
+            fragment = fragment_cls(memory_content)
         await self.memory.write(fragment)

         action_output.memory_fragments = {
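
write_memories now records the whole ReAct step: action and action_input come from the ActionOutput, a failed check overrides the observation, and the original question is stored only on the first attempt (current_retry_counter == 0) so retries do not duplicate it. The resulting map for a hypothetical first attempt would look like:

    memory_map = {
        "thought": "I need to inspect the table first",
        "action": "list_tables",
        "observation": "orders, users",
        # Present only when the action had input:
        "action_input": "{}",
        # Present only on the first attempt of a task:
        "question": "What is the average order value?",
    }
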
@@ -270,9 +295,10 @@ class Role(ABC, BaseModel):
     async def recovering_memory(self, action_outputs: List[ActionOutput]) -> None:
         """Recover the memory from the action outputs."""
         fragments = []
+        fragment_cls: Type[AgentMemoryFragment] = self.memory_fragment_class
         for action_output in action_outputs:
             if action_output.memory_fragments:
-                fragment = AgentMemoryFragment.build_from(
+                fragment = fragment_cls.build_from(
                     observation=action_output.memory_fragments["memory"],
                     importance=action_output.memory_fragments.get("importance"),
                     memory_id=action_output.memory_fragments.get("id"),
|
@ -1,5 +1,6 @@
|
|||||||
|
import json
|
||||||
import logging
|
import logging
|
||||||
from typing import Any, Dict, List, Optional
|
from typing import Any, Dict, List, Optional, Type, Union
|
||||||
|
|
||||||
from dbgpt._private.pydantic import Field
|
from dbgpt._private.pydantic import Field
|
||||||
from dbgpt.agent import (
|
from dbgpt.agent import (
|
||||||
@@ -11,12 +12,14 @@ from dbgpt.agent import (
     ProfileConfig,
     Resource,
     ResourceType,
+    StructuredAgentMemoryFragment,
 )
 from dbgpt.agent.core.role import AgentRunMode
 from dbgpt.agent.resource import BaseTool, ResourcePack, ToolPack
 from dbgpt.agent.util.react_parser import ReActOutputParser
 from dbgpt.util.configure import DynConfig

+from ...core import ModelMessageRoleType
 from .actions.react_action import ReActAction, Terminate

 logger = logging.getLogger(__name__)
@@ -63,12 +66,9 @@ Please Solve this task:
 Please answer in the same language as the user's question.
 The current time is: {{ now_time }}.
 """
-_REACT_USER_TEMPLATE = """\
-{% if most_recent_memories %}\
-Most recent message:
-{{ most_recent_memories }}
-{% endif %}\
-"""
+# Not needed additional user prompt template
+_REACT_USER_TEMPLATE = """"""


 _REACT_WRITE_MEMORY_TEMPLATE = """\
@@ -225,7 +225,10 @@ class ReActAgent(ConversableAgent):
         steps = self.parser.parse(message_content)
         err_msg = None
         if not steps:
-            err_msg = "No correct response found."
+            err_msg = (
+                "No correct response found. Please check your response, which must"
+                " be in the format indicated in the system prompt."
+            )
         elif len(steps) != 1:
             err_msg = "Only one action is allowed each time."
         if err_msg:
@@ -243,56 +246,72 @@ class ReActAgent(ConversableAgent):
             )
         return action_output

-    async def write_memories(
-        self,
-        question: str,
-        ai_message: str,
-        action_output: Optional[ActionOutput] = None,
-        check_pass: bool = True,
-        check_fail_reason: Optional[str] = None,
-        current_retry_counter: Optional[int] = None,
-    ) -> AgentMemoryFragment:
-        """Write the memories to the memory.
-
-        We suggest you to override this method to save the conversation to memory
-        according to your needs.
-
-        Args:
-            question(str): The question received.
-            ai_message(str): The AI message, LLM output.
-            action_output(ActionOutput): The action output.
-            check_pass(bool): Whether the check pass.
-            check_fail_reason(str): The check fail reason.
-
-        Returns:
-            AgentMemoryFragment: The memory fragment created.
-        """
-        if not action_output:
-            raise ValueError("Action output is required to save to memory.")
-
-        mem_thoughts = action_output.thoughts or ai_message
-        action = action_output.action
-        action_input = action_output.action_input
-        observation = check_fail_reason or action_output.observations
-
-        memory_map = {
-            "thought": mem_thoughts,
-            "action": action,
-            "observation": observation,
-        }
-        if action_input:
-            memory_map["action_input"] = action_input
-
-        if current_retry_counter is not None and current_retry_counter == 0:
-            memory_map["question"] = question
-
-        write_memory_template = self.write_memory_template
-        memory_content = self._render_template(write_memory_template, **memory_map)
-        fragment = AgentMemoryFragment(memory_content)
-        await self.memory.write(fragment)
-        action_output.memory_fragments = {
-            "memory": fragment.raw_observation,
-            "id": fragment.id,
-            "importance": fragment.importance,
-        }
-        return fragment
+    @property
+    def memory_fragment_class(self) -> Type[AgentMemoryFragment]:
+        """Return the memory fragment class."""
+        return StructuredAgentMemoryFragment
+
+    async def read_memories(
+        self,
+        observation: str,
+    ) -> Union[str, List["AgentMessage"]]:
+        memories = await self.memory.read(observation)
+        not_json_memories = []
+        messages = []
+        structured_memories = []
+        for m in memories:
+            if m.raw_observation:
+                try:
+                    mem_dict = json.loads(m.raw_observation)
+                    if isinstance(mem_dict, dict):
+                        structured_memories.append(mem_dict)
+                    elif isinstance(mem_dict, list):
+                        structured_memories.extend(mem_dict)
+                    else:
+                        raise ValueError("Invalid memory format.")
+                except Exception:
+                    not_json_memories.append(m.raw_observation)
+
+        for mem_dict in structured_memories:
+            question = mem_dict.get("question")
+            thought = mem_dict.get("thought")
+            action = mem_dict.get("action")
+            action_input = mem_dict.get("action_input")
+            observation = mem_dict.get("observation")
+            if question:
+                messages.append(
+                    AgentMessage(
+                        content=f"Question: {question}",
+                        role=ModelMessageRoleType.HUMAN,
+                    )
+                )
+            ai_content = []
+            if thought:
+                ai_content.append(f"Thought: {thought}")
+            if action:
+                ai_content.append(f"Action: {action}")
+            if action_input:
+                ai_content.append(f"Action Input: {action_input}")
+            messages.append(
+                AgentMessage(
+                    content="\n".join(ai_content),
+                    role=ModelMessageRoleType.AI,
+                )
+            )
+
+            if observation:
+                messages.append(
+                    AgentMessage(
+                        content=f"Observation: {observation}",
+                        role=ModelMessageRoleType.HUMAN,
+                    )
+                )
+
+        if not messages and not_json_memories:
+            messages.append(
+                AgentMessage(
+                    content="\n".join(not_json_memories),
+                    role=ModelMessageRoleType.HUMAN,
+                )
+            )
+        return messages

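ReActAgent.read_memories replays each stored StructuredObservation as alternating chat turns: the question and observation become human messages, while thought, action, and action input are folded into one AI message. How a single stored fragment expands, traced with plain json calls (sample data is made up):

    import json

    raw = json.dumps(
        {
            "question": "How many users signed up today?",
            "thought": "Query the users table.",
            "action": "run_sql",
            "action_input": "SELECT COUNT(*) FROM users;",
            "observation": "128",
        }
    )
    mem = json.loads(raw)
    turns = [("human", f"Question: {mem['question']}")]
    ai_content = [
        f"Thought: {mem['thought']}",
        f"Action: {mem['action']}",
        f"Action Input: {mem['action_input']}",
    ]
    turns.append(("ai", "\n".join(ai_content)))
    turns.append(("human", f"Observation: {mem['observation']}"))
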
@@ -67,10 +67,12 @@ class StartAppAssistantAgent(ConversableAgent):
         self,
         received_message: AgentMessage,
         sender: Agent,
+        observation: Optional[str] = None,
         rely_messages: Optional[List[AgentMessage]] = None,
         historical_dialogues: Optional[List[AgentMessage]] = None,
         context: Optional[Dict[str, Any]] = None,
         is_retry_chat: bool = False,
+        current_retry_counter: Optional[int] = None,
     ) -> Tuple[List[AgentMessage], Optional[Dict]]:
         if rely_messages and len(rely_messages) > 0:
             return rely_messages[-1:], None