Mirror of https://github.com/csunny/DB-GPT.git (synced 2025-09-08 12:30:14 +00:00)
refactor(agent): Agent modular refactoring (#1487)
dbgpt/agent/core/memory/__init__.py (Normal file, 16 lines)
@@ -0,0 +1,16 @@
"""Memory module for the agent."""

from .agent_memory import AgentMemory, AgentMemoryFragment  # noqa: F401
from .base import (  # noqa: F401
    ImportanceScorer,
    InsightExtractor,
    InsightMemoryFragment,
    Memory,
    MemoryFragment,
    SensoryMemory,
    ShortTermMemory,
)
from .hybrid import HybridMemory  # noqa: F401
from .llm import LLMImportanceScorer, LLMInsightExtractor  # noqa: F401
from .long_term import LongTermMemory, LongTermRetriever  # noqa: F401
from .short_term import EnhancedShortTermMemory  # noqa: F401
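Note: a minimal usage sketch of the API exported above, assuming the refactored package is importable as dbgpt.agent.core.memory and run in-process with the default short-term backend (the observation text is illustrative):

import asyncio

from dbgpt.agent.core.memory import AgentMemory, AgentMemoryFragment, ShortTermMemory


async def main():
    # AgentMemory falls back to ShortTermMemory(buffer_size=5) if none is given.
    memory = AgentMemory(memory=ShortTermMemory(buffer_size=5))
    memory.initialize(name="demo")
    await memory.write(AgentMemoryFragment("User asked to summarize sales data."))
    fragments = await memory.read("sales data")
    print([f.raw_observation for f in fragments])


asyncio.run(main())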
dbgpt/agent/core/memory/agent_memory.py (Normal file, 282 lines)
@@ -0,0 +1,282 @@
"""Agent memory module."""

from datetime import datetime
from typing import Callable, List, Optional, Type, cast

from dbgpt.core import LLMClient
from dbgpt.util.annotations import immutable, mutable
from dbgpt.util.id_generator import new_id

from .base import (
    DiscardedMemoryFragments,
    ImportanceScorer,
    InsightExtractor,
    Memory,
    MemoryFragment,
    ShortTermMemory,
    WriteOperation,
)
from .gpts import GptsMemory, GptsMessageMemory, GptsPlansMemory


class AgentMemoryFragment(MemoryFragment):
    """Default memory fragment for agent memory."""

    def __init__(
        self,
        observation: str,
        embeddings: Optional[List[float]] = None,
        memory_id: Optional[int] = None,
        importance: Optional[float] = None,
        last_accessed_time: Optional[datetime] = None,
        is_insight: bool = False,
    ):
        """Create a memory fragment."""
        if not memory_id:
            # Generate a new memory id; we use a snowflake id generator here.
            memory_id = new_id()
        self.observation = observation
        self._embeddings = embeddings
        self.memory_id: int = cast(int, memory_id)
        self._importance: Optional[float] = importance
        self._last_accessed_time: Optional[datetime] = last_accessed_time
        self._is_insight = is_insight

    @property
    def id(self) -> int:
        """Return the memory id."""
        return self.memory_id

    @property
    def raw_observation(self) -> str:
        """Return the raw observation."""
        return self.observation

    @property
    def embeddings(self) -> Optional[List[float]]:
        """Return the embeddings of the memory fragment."""
        return self._embeddings

    def update_embeddings(self, embeddings: List[float]) -> None:
        """Update the embeddings of the memory fragment.

        Args:
            embeddings(List[float]): embeddings
        """
        self._embeddings = embeddings

    def calculate_current_embeddings(
        self, embedding_func: Callable[[List[str]], List[List[float]]]
    ) -> List[float]:
        """Calculate the embeddings of the memory fragment.

        Args:
            embedding_func(Callable[[List[str]], List[List[float]]]): Function to
                compute embeddings

        Returns:
            List[float]: Embeddings of the memory fragment
        """
        embeddings = embedding_func([self.observation])
        return embeddings[0]

    @property
    def is_insight(self) -> bool:
        """Return whether the memory fragment is an insight.

        Returns:
            bool: Whether the memory fragment is an insight
        """
        return self._is_insight

    @property
    def importance(self) -> Optional[float]:
        """Return the importance of the memory fragment.

        Returns:
            Optional[float]: Importance of the memory fragment
        """
        return self._importance

    def update_importance(self, importance: float) -> Optional[float]:
        """Update the importance of the memory fragment.

        Args:
            importance(float): Importance of the memory fragment

        Returns:
            Optional[float]: Old importance
        """
        old_importance = self._importance
        self._importance = importance
        return old_importance

    @property
    def last_accessed_time(self) -> Optional[datetime]:
        """Return the last accessed time of the memory fragment.

        Used to determine the least recently used memory fragment.

        Returns:
            Optional[datetime]: Last accessed time
        """
        return self._last_accessed_time

    def update_accessed_time(self, now: datetime) -> Optional[datetime]:
        """Update the last accessed time of the memory fragment.

        Args:
            now(datetime): Current time

        Returns:
            Optional[datetime]: Old last accessed time
        """
        old_time = self._last_accessed_time
        self._last_accessed_time = now
        return old_time

    @classmethod
    def build_from(
        cls: Type["AgentMemoryFragment"],
        observation: str,
        embeddings: Optional[List[float]] = None,
        memory_id: Optional[int] = None,
        importance: Optional[float] = None,
        is_insight: bool = False,
        last_accessed_time: Optional[datetime] = None,
        **kwargs,
    ) -> "AgentMemoryFragment":
        """Build a memory fragment from the given parameters."""
        return cls(
            observation=observation,
            embeddings=embeddings,
            memory_id=memory_id,
            importance=importance,
            last_accessed_time=last_accessed_time,
            is_insight=is_insight,
        )

    def copy(self: "AgentMemoryFragment") -> "AgentMemoryFragment":
        """Return a copy of the memory fragment."""
        return AgentMemoryFragment.build_from(
            observation=self.observation,
            embeddings=self._embeddings,
            memory_id=self.memory_id,
            importance=self.importance,
            last_accessed_time=self.last_accessed_time,
            is_insight=self.is_insight,
        )


class AgentMemory(Memory[AgentMemoryFragment]):
    """Agent memory."""

    def __init__(
        self,
        memory: Optional[Memory[AgentMemoryFragment]] = None,
        importance_scorer: Optional[ImportanceScorer[AgentMemoryFragment]] = None,
        insight_extractor: Optional[InsightExtractor[AgentMemoryFragment]] = None,
        gpts_memory: Optional[GptsMemory] = None,
    ):
        """Create an agent memory.

        Args:
            memory(Memory[AgentMemoryFragment]): Memory to store fragments
            importance_scorer(ImportanceScorer[AgentMemoryFragment]): Scorer to
                calculate the importance of memory fragments
            insight_extractor(InsightExtractor[AgentMemoryFragment]): Extractor to
                extract insights from memory fragments
            gpts_memory(GptsMemory): Memory to store GPTs related information
        """
        if not memory:
            memory = ShortTermMemory(buffer_size=5)
        if not gpts_memory:
            gpts_memory = GptsMemory()
        self.memory: Memory[AgentMemoryFragment] = cast(
            Memory[AgentMemoryFragment], memory
        )
        self.importance_scorer = importance_scorer
        self.insight_extractor = insight_extractor
        self.gpts_memory = gpts_memory

    @immutable
    def structure_clone(
        self: "AgentMemory", now: Optional[datetime] = None
    ) -> "AgentMemory":
        """Return a structure clone of the memory.

        The gpts_memory is not cloned; it is shared across the whole agent memory.
        """
        m = AgentMemory(
            memory=self.memory.structure_clone(now),
            importance_scorer=self.importance_scorer,
            insight_extractor=self.insight_extractor,
            gpts_memory=self.gpts_memory,
        )
        m._copy_from(self)
        return m

    @mutable
    def initialize(
        self,
        name: Optional[str] = None,
        llm_client: Optional[LLMClient] = None,
        importance_scorer: Optional[ImportanceScorer[AgentMemoryFragment]] = None,
        insight_extractor: Optional[InsightExtractor[AgentMemoryFragment]] = None,
        real_memory_fragment_class: Optional[Type[AgentMemoryFragment]] = None,
    ) -> None:
        """Initialize the memory."""
        self.memory.initialize(
            name=name,
            llm_client=llm_client,
            importance_scorer=importance_scorer or self.importance_scorer,
            insight_extractor=insight_extractor or self.insight_extractor,
            real_memory_fragment_class=real_memory_fragment_class
            or AgentMemoryFragment,
        )

    @mutable
    async def write(
        self,
        memory_fragment: AgentMemoryFragment,
        now: Optional[datetime] = None,
        op: WriteOperation = WriteOperation.ADD,
    ) -> Optional[DiscardedMemoryFragments[AgentMemoryFragment]]:
        """Write a memory fragment to the memory."""
        return await self.memory.write(memory_fragment, now)

    @immutable
    async def read(
        self,
        observation: str,
        alpha: Optional[float] = None,
        beta: Optional[float] = None,
        gamma: Optional[float] = None,
    ) -> List[AgentMemoryFragment]:
        """Read memory fragments related to the observation.

        Args:
            observation(str): Observation
            alpha(float): Recency weight
            beta(float): Relevance weight
            gamma(float): Importance weight

        Returns:
            List[AgentMemoryFragment]: List of memory fragments
        """
        return await self.memory.read(observation, alpha, beta, gamma)

    @mutable
    async def clear(self) -> List[AgentMemoryFragment]:
        """Clear the memory."""
        return await self.memory.clear()

    @property
    def plans_memory(self) -> GptsPlansMemory:
        """Return the plans memory."""
        return self.gpts_memory.plans_memory

    @property
    def message_memory(self) -> GptsMessageMemory:
        """Return the message memory."""
        return self.gpts_memory.message_memory
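Note: a short sketch of the fragment lifecycle defined above (values illustrative; copy() preserves the snowflake-style memory id because it passes memory_id back to build_from):

from datetime import datetime

from dbgpt.agent.core.memory import AgentMemoryFragment

frag = AgentMemoryFragment.build_from("Observed a spike in error logs.")
old = frag.update_importance(0.8)  # returns the previous importance (None here)
frag.update_accessed_time(datetime.now())
clone = frag.copy()
assert clone.id == frag.id and clone.raw_observation == frag.raw_observation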
dbgpt/agent/core/memory/base.py (Normal file, 776 lines)
@@ -0,0 +1,776 @@
"""Memory for agent.

Human memory follows a general progression from sensory memory that registers
perceptual inputs, to short-term memory that maintains information transiently, to
long-term memory that consolidates information over extended periods.
"""

import asyncio
from abc import ABC, abstractmethod
from datetime import datetime
from enum import Enum
from typing import (
    Any,
    Callable,
    Dict,
    Generic,
    List,
    Optional,
    Tuple,
    Type,
    TypeVar,
    Union,
    cast,
)

from dbgpt.core import LLMClient
from dbgpt.util.annotations import PublicAPI, immutable, mutable

T = TypeVar("T", bound="MemoryFragment")
M = TypeVar("M", bound="Memory")


class WriteOperation(str, Enum):
    """Write operation."""

    ADD = "add"
    RETRIEVAL = "retrieval"


@PublicAPI(stability="beta")
class MemoryFragment(ABC):
    """Memory fragment interface.

    The basic unit of memory. It carries the core information of a memory, such as
    the observation, its importance, whether it is an insight, and the last access
    time.
    """

    @classmethod
    @abstractmethod
    def build_from(
        cls: Type[T],
        observation: str,
        embeddings: Optional[List[float]] = None,
        memory_id: Optional[int] = None,
        importance: Optional[float] = None,
        is_insight: bool = False,
        last_accessed_time: Optional[datetime] = None,
        **kwargs,
    ) -> T:
        """Build a memory fragment from a memory id and an observation.

        Args:
            observation(str): Observation
            embeddings(List[float], optional): Embeddings of the memory fragment
            memory_id(int): Memory id
            importance(float): Importance
            is_insight(bool): Whether the memory fragment is an insight
            last_accessed_time(datetime): Last accessed time

        Returns:
            MemoryFragment: Memory fragment
        """
        raise NotImplementedError

    @property
    @abstractmethod
    def id(self) -> int:
        """Return the id of the memory fragment.

        Commonly, the id is generated by the Snowflake algorithm, so the timestamp
        of when the memory fragment was created can be parsed from it.

        Returns:
            int: id
        """

    @property
    def metadata(self) -> Dict[str, Any]:
        """Return the metadata of the memory fragment.

        Returns:
            Dict[str, Any]: Metadata
        """
        return {}

    @property
    def importance(self) -> Optional[float]:
        """Return the importance of the memory fragment.

        Note that importance only reflects the characteristics of the memory itself.

        Returns:
            Optional[float]: importance, None means the importance is not available.
        """
        return None

    @abstractmethod
    def update_importance(self, importance: float) -> Optional[float]:
        """Update the importance of the memory fragment.

        Args:
            importance(float): importance

        Returns:
            Optional[float]: The old importance
        """

    @property
    @abstractmethod
    def raw_observation(self) -> str:
        """Return the raw observation.

        The raw observation is the original observation data; it can be an
        observation from the environment or an observation after executing an
        action.

        Returns:
            str: raw observation
        """

    @property
    def embeddings(self) -> Optional[List[float]]:
        """Return the embeddings of the memory fragment.

        Returns:
            Optional[List[float]]: embeddings
        """
        return None

    @abstractmethod
    def update_embeddings(self, embeddings: List[float]) -> None:
        """Update the embeddings of the memory fragment.

        Args:
            embeddings(List[float]): embeddings
        """

    def calculate_current_embeddings(
        self, embedding_func: Callable[[List[str]], List[List[float]]]
    ) -> List[float]:
        """Calculate the embeddings of the memory fragment.

        Args:
            embedding_func(Callable[[List[str]], List[List[float]]]): Function to
                compute embeddings

        Returns:
            List[float]: Embeddings of the memory fragment
        """
        raise NotImplementedError

    @property
    @abstractmethod
    def is_insight(self) -> bool:
        """Return whether the memory fragment is an insight.

        Returns:
            bool: whether the memory fragment is an insight.
        """

    @property
    @abstractmethod
    def last_accessed_time(self) -> Optional[datetime]:
        """Return the last accessed time of the memory fragment.

        Returns:
            Optional[datetime]: last accessed time
        """

    @abstractmethod
    def update_accessed_time(self, now: datetime) -> Optional[datetime]:
        """Update the last accessed time of the memory fragment.

        Args:
            now(datetime): The current time

        Returns:
            Optional[datetime]: The old last accessed time
        """

    @abstractmethod
    def copy(self: T) -> T:
        """Copy the memory fragment."""

    def reduce(self, memory_fragments: List[T], **kwargs) -> T:
        """Reduce memory fragments to a single memory fragment.

        Args:
            memory_fragments(List[T]): Memory fragments

        Returns:
            T: The reduced memory fragment
        """
        obs = []
        for memory_fragment in memory_fragments:
            obs.append(memory_fragment.raw_observation)
        new_observation = ";".join(obs)
        return self.current_class.build_from(new_observation, **kwargs)  # type: ignore

    @property
    def current_class(self: T) -> Type[T]:
        """Return the current class."""
        return self.__class__


class InsightMemoryFragment(Generic[T]):
    """Insight memory fragment.

    An insight memory fragment is a memory fragment that contains insights.
    """

    def __init__(
        self,
        original_memory_fragment: Union[T, List[T]],
        insights: Union[List[T], List[str]],
    ):
        """Create an insight memory fragment.

        An insight is also a memory fragment.
        """
        if insights and isinstance(insights[0], str):
            mf = (
                original_memory_fragment[0]
                if isinstance(original_memory_fragment, list)
                else original_memory_fragment
            )
            insights = [
                mf.current_class.build_from(i, is_insight=True)  # type: ignore
                for i in insights
            ]
        self._original_memory_fragment = original_memory_fragment
        self._insights: List[T] = cast(List[T], insights)

    @property
    def original_memory_fragment(self) -> Union[T, List[T]]:
        """Return the original memory fragment."""
        return self._original_memory_fragment

    @property
    def insights(self) -> List[T]:
        """Return the insights."""
        return self._insights


class DiscardedMemoryFragments(Generic[T]):
    """Discarded memory fragments.

    Sometimes we need to discard memory fragments, in the following cases:

    1. Memory is duplicated: the same/similar action is executed multiple times and
       the same/similar observation is received from the environment.
    2. Memory overflow: the memory is full and a new memory fragment needs to be
       written.
    3. The memory fragment is not important enough.
    4. Simulation of a forgetting mechanism.

    The discarded memory fragments may be transferred to another memory.
    """

    def __init__(
        self,
        discarded_memory_fragments: List[T],
        discarded_insights: Optional[List[InsightMemoryFragment[T]]] = None,
    ):
        """Create discarded memory fragments."""
        if discarded_insights is None:
            discarded_insights = []
        self._discarded_memory_fragments = discarded_memory_fragments
        self._discarded_insights = discarded_insights

    @property
    def discarded_memory_fragments(self) -> List[T]:
        """Return the discarded memory fragments."""
        return self._discarded_memory_fragments

    @property
    def discarded_insights(self) -> List[InsightMemoryFragment[T]]:
        """Return the discarded insights."""
        return self._discarded_insights


class InsightExtractor(ABC, Generic[T]):
    """Insight extractor interface.

    Obtains high-level insights from memories.
    """

    @abstractmethod
    async def extract_insights(
        self,
        memory_fragment: T,
        llm_client: Optional[LLMClient] = None,
    ) -> InsightMemoryFragment[T]:
        """Extract insights from memory fragments.

        Args:
            memory_fragment(T): Memory fragment
            llm_client(Optional[LLMClient]): LLM client

        Returns:
            InsightMemoryFragment: The insights of the memory fragment.
        """


class ImportanceScorer(ABC, Generic[T]):
    """Importance scorer interface.

    Scores the importance of memories.
    """

    @abstractmethod
    async def score_importance(
        self,
        memory_fragment: T,
        llm_client: Optional[LLMClient] = None,
    ) -> float:
        """Score the importance of a memory fragment.

        Args:
            memory_fragment(T): Memory fragment.
            llm_client(Optional[LLMClient]): LLM client

        Returns:
            float: The importance of the memory fragment.
        """


@PublicAPI(stability="beta")
class Memory(ABC, Generic[T]):
    """Memory interface."""

    name: Optional[str] = None
    llm_client: Optional[LLMClient] = None
    importance_scorer: Optional[ImportanceScorer] = None
    insight_extractor: Optional[InsightExtractor] = None
    _real_memory_fragment_class: Optional[Type[T]] = None
    importance_weight: float = 0.15

    @mutable
    def initialize(
        self,
        name: Optional[str] = None,
        llm_client: Optional[LLMClient] = None,
        importance_scorer: Optional[ImportanceScorer] = None,
        insight_extractor: Optional[InsightExtractor] = None,
        real_memory_fragment_class: Optional[Type[T]] = None,
    ) -> None:
        """Initialize memory.

        Some agents may need to initialize memory before using it.
        """
        self.name = name
        self.llm_client = llm_client
        self.importance_scorer = importance_scorer
        self.insight_extractor = insight_extractor
        self._real_memory_fragment_class = real_memory_fragment_class

    @abstractmethod
    @immutable
    def structure_clone(self: M, now: Optional[datetime] = None) -> M:
        """Return a structure clone of the memory.

        Sometimes we need to clone the structure of the memory, but not the content.

        There are some cases:

        1. When we need to reset the memory, we can use this method to create a new
           one, and the new memory has the same structure as the old one.
        2. When creating a new agent, the new agent has the same memory structure as
           the planner.

        Args:
            now(Optional[datetime]): The current time

        Returns:
            M: The structure clone of the memory
        """
        raise NotImplementedError

    @mutable
    def _copy_from(self, memory: "Memory") -> None:
        """Copy memory from another memory.

        Args:
            memory(Memory): Another memory
        """
        self.name = memory.name
        self.llm_client = memory.llm_client
        self.importance_scorer = memory.importance_scorer
        self.insight_extractor = memory.insight_extractor
        self._real_memory_fragment_class = memory._real_memory_fragment_class

    @abstractmethod
    @mutable
    async def write(
        self,
        memory_fragment: T,
        now: Optional[datetime] = None,
        op: WriteOperation = WriteOperation.ADD,
    ) -> Optional[DiscardedMemoryFragments[T]]:
        """Write a memory fragment to memory.

        Two situations need to be noted here:

        1. Memory is duplicated: the same/similar action is executed multiple times
           and the same/similar observation is received from the environment.
        2. Memory overflow: the memory is full and the new memory fragment needs to
           be written to memory; the common strategy is to discard some memory
           fragments.

        Args:
            memory_fragment(T): Memory fragment
            now(Optional[datetime]): The current time
            op(WriteOperation): Write operation

        Returns:
            Optional[DiscardedMemoryFragments]: The discarded memory fragments, None
                means no memory fragments are discarded.
        """

    @mutable
    async def write_batch(
        self, memory_fragments: List[T], now: Optional[datetime] = None
    ) -> Optional[DiscardedMemoryFragments[T]]:
        """Write a batch of memory fragments to memory.

        Args:
            memory_fragments(List[T]): Memory fragments
            now(Optional[datetime]): The current time

        Returns:
            Optional[DiscardedMemoryFragments]: The discarded memory fragments, None
                means no memory fragments are discarded.
        """
        raise NotImplementedError

    @abstractmethod
    @immutable
    async def read(
        self,
        observation: str,
        alpha: Optional[float] = None,
        beta: Optional[float] = None,
        gamma: Optional[float] = None,
    ) -> List[T]:
        r"""Read memory fragments by observation.

        There are three commonly used criteria for information extraction:
        recency, relevance, and importance.

        Memories that are more recent, relevant, and important are more likely to be
        extracted. Formally, we conclude the following equation from existing
        literature for memory information extraction:

        .. math::

            m^* = \arg\min_{m \in M} \alpha s^{\text{rec}}(q, m) + \\
            \beta s^{\text{rel}}(q, m) + \gamma s^{\text{imp}}(m), \tag{1}

        Args:
            observation(str): observation (query)
            alpha(float, optional): Recency coefficient. Default is None.
            beta(float, optional): Relevance coefficient. Default is None.
            gamma(float, optional): Importance coefficient. Default is None.

        Returns:
            List[T]: memory fragments
        """

    @immutable
    async def reflect(self, memory_fragments: List[T]) -> List[T]:
        """Reflect on memory fragments.

        Args:
            memory_fragments(List[T]): memory fragments to be reflected on.

        Returns:
            List[T]: memory fragments after reflection.
        """
        return memory_fragments

    @immutable
    async def handle_duplicated(
        self, memory_fragments: List[T], new_memory_fragments: List[T]
    ) -> List[T]:
        """Handle duplicated memory fragments.

        Args:
            memory_fragments(List[T]): Existing memory fragments
            new_memory_fragments(List[T]): New memory fragments

        Returns:
            List[T]: The memory fragments after handling duplicates.
        """
        return memory_fragments + new_memory_fragments

    @mutable
    async def handle_overflow(
        self, memory_fragments: List[T]
    ) -> Tuple[List[T], List[T]]:
        """Handle memory overflow.

        Args:
            memory_fragments(List[T]): Existing memory fragments

        Returns:
            Tuple[List[T], List[T]]: The memory fragments after handling overflow and
                the discarded memory fragments.
        """
        return memory_fragments, []

    @abstractmethod
    @mutable
    async def clear(self) -> List[T]:
        """Clear all memory fragments.

        Returns:
            List[T]: All cleared memory fragments.
        """

    @immutable
    async def get_insights(
        self, memory_fragments: List[T]
    ) -> List[InsightMemoryFragment[T]]:
        """Get insights from memory fragments.

        Args:
            memory_fragments(List[T]): Memory fragments

        Returns:
            List[InsightMemoryFragment]: The insights of the memory fragments.
        """
        if not self.insight_extractor:
            return []
        # Obtain insights from memory fragments in parallel
        tasks = []
        for memory_fragment in memory_fragments:
            tasks.append(
                self.insight_extractor.extract_insights(
                    memory_fragment, self.llm_client
                )
            )
        insights = await asyncio.gather(*tasks)
        result = []
        for insight in insights:
            if not insight:
                continue
            result.append(insight)
        if len(result) != len(insights):
            raise ValueError(
                "The number of insights is not equal to the number of memory fragments."
            )
        return result

    @immutable
    async def score_memory_importance(self, memory_fragments: List[T]) -> List[float]:
        """Score the importance of memory fragments.

        Args:
            memory_fragments(List[T]): Memory fragments

        Returns:
            List[float]: The importance of the memory fragments.
        """
        if not self.importance_scorer:
            return [5 * self.importance_weight for _ in memory_fragments]
        tasks = []
        for memory_fragment in memory_fragments:
            tasks.append(
                self.importance_scorer.score_importance(
                    memory_fragment, self.llm_client
                )
            )
        result = []
        for importance in await asyncio.gather(*tasks):
            real_score = importance * self.importance_weight
            result.append(real_score)
        return result

    @property
    @immutable
    def real_memory_fragment_class(self) -> Type[T]:
        """Return the real memory fragment class."""
        if not self._real_memory_fragment_class:
            raise ValueError("The real memory fragment class is not set.")
        return self._real_memory_fragment_class


class SensoryMemory(Memory, Generic[T]):
    """Sensory memory."""

    importance_weight: float = 0.9
    threshold_to_short_term: float = 0.1

    def __init__(self, buffer_size: int = 0):
        """Create a sensory memory."""
        self._buffer_size = buffer_size
        self._fragments: List[T] = []
        self._lock = asyncio.Lock()

    def structure_clone(
        self: "SensoryMemory[T]", now: Optional[datetime] = None
    ) -> "SensoryMemory[T]":
        """Return a structure clone of the memory."""
        m: SensoryMemory[T] = SensoryMemory(buffer_size=self._buffer_size)
        m._copy_from(self)
        return m

    @mutable
    async def write(
        self,
        memory_fragment: T,
        now: Optional[datetime] = None,
        op: WriteOperation = WriteOperation.ADD,
    ) -> Optional[DiscardedMemoryFragments[T]]:
        """Write a memory fragment to sensory memory."""
        fragments = await self.handle_duplicated(self._fragments, [memory_fragment])
        discarded_fragments: List[T] = []
        if len(fragments) > self._buffer_size:
            fragments, discarded_fragments = await self.handle_overflow(fragments)

        async with self._lock:
            await self.clear()
            self._fragments = fragments
            if not discarded_fragments:
                return None
            return DiscardedMemoryFragments(discarded_fragments, [])

    @immutable
    async def read(
        self,
        observation: str,
        alpha: Optional[float] = None,
        beta: Optional[float] = None,
        gamma: Optional[float] = None,
    ) -> List[T]:
        """Read memory fragments by observation."""
        return self._fragments

    @mutable
    async def handle_overflow(
        self, memory_fragments: List[T]
    ) -> Tuple[List[T], List[T]]:
        """Handle memory overflow.

        For sensory memory, the overflow strategy is to transfer all memory fragments
        to short-term memory.

        Args:
            memory_fragments(List[T]): Existing memory fragments

        Returns:
            Tuple[List[T], List[T]]: The memory fragments after handling overflow and
                the discarded memory fragments; the discarded memory fragments should
                be transferred to short-term memory.
        """
        scores = await self.score_memory_importance(memory_fragments)
        result = []
        for i, memory in enumerate(memory_fragments):
            if scores[i] >= self.threshold_to_short_term:
                memory.update_importance(scores[i])
                result.append(memory)
        return [], result

    @mutable
    async def clear(self) -> List[T]:
        """Clear all memory fragments."""
        # async with self._lock:
        fragments = self._fragments
        self._fragments = []
        return fragments


class ShortTermMemory(Memory, Generic[T]):
    """Short-term memory.

    All memories are stored in computer memory.
    """

    def __init__(self, buffer_size: int = 5):
        """Create a short-term memory."""
        self._buffer_size = buffer_size
        self._fragments: List[T] = []
        self._lock = asyncio.Lock()

    def structure_clone(
        self: "ShortTermMemory[T]", now: Optional[datetime] = None
    ) -> "ShortTermMemory[T]":
        """Return a structure clone of the memory."""
        m: ShortTermMemory[T] = ShortTermMemory(buffer_size=self._buffer_size)
        m._copy_from(self)
        return m

    @mutable
    async def write(
        self,
        memory_fragment: T,
        now: Optional[datetime] = None,
        op: WriteOperation = WriteOperation.ADD,
    ) -> Optional[DiscardedMemoryFragments[T]]:
        """Write a memory fragment to short-term memory.

        Args:
            memory_fragment(T): New memory fragment
            now(Optional[datetime]): The current time
            op(WriteOperation): Write operation

        Returns:
            Optional[DiscardedMemoryFragments]: The discarded memory fragments, None
                means no memory fragments are discarded. The discarded memory
                fragments should be transferred to and stored in long-term memory.
        """
        fragments = await self.handle_duplicated(self._fragments, [memory_fragment])

        async with self._lock:
            await self.clear()
            self._fragments = fragments
            discarded_memories = await self.transfer_to_long_term(memory_fragment)
            fragments, discarded_fragments = await self.handle_overflow(
                self._fragments
            )
            self._fragments = fragments
            return discarded_memories

    @immutable
    async def read(
        self,
        observation: str,
        alpha: Optional[float] = None,
        beta: Optional[float] = None,
        gamma: Optional[float] = None,
    ) -> List[T]:
        """Read memory fragments by observation."""
        return self._fragments

    @mutable
    async def transfer_to_long_term(
        self, memory_fragment: T
    ) -> Optional[DiscardedMemoryFragments[T]]:
        """Transfer the oldest memories to long-term memory.

        This is a very simple strategy: just transfer the oldest memories to
        long-term memory.
        """
        if len(self._fragments) > self._buffer_size:
            overflow_cnt = len(self._fragments) - self._buffer_size
            # Take the oldest memories before truncating the buffer, so the
            # transferred fragments are the ones actually evicted
            overflow_fragments = self._fragments[:overflow_cnt]
            # Just keep the most recent memories in short-term memory
            self._fragments = self._fragments[overflow_cnt:]
            insights = await self.get_insights(overflow_fragments)
            return DiscardedMemoryFragments(overflow_fragments, insights)
        else:
            return None

    @mutable
    async def clear(self) -> List[T]:
        """Clear all memory fragments."""
        # async with self._lock:
        fragments = self._fragments
        self._fragments = []
        return fragments

    @property
    @immutable
    def short_term_memories(self) -> List[T]:
        """Return short-term memories."""
        return self._fragments
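Note: to make equation (1) in Memory.read concrete, here is a hedged sketch of the weighted combination it describes; the three scoring functions are hypothetical stand-ins (e.g. time decay, embedding similarity, stored importance), not part of this module:

def combined_score(
    alpha: float,
    beta: float,
    gamma: float,
    recency: float,  # s_rec(q, m): e.g. decays with time since last access
    relevance: float,  # s_rel(q, m): e.g. cosine similarity of embeddings
    importance: float,  # s_imp(m): e.g. the stored fragment importance
) -> float:
    # Weighted combination used to rank memory fragments for retrieval.
    return alpha * recency + beta * relevance + gamma * importance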
dbgpt/agent/core/memory/gpts/__init__.py (Normal file, 19 lines)
@@ -0,0 +1,19 @@
"""Memory module for GPTs messages and plans.

It stores the messages and plans generated by multiple agents in the conversation.

Unlike the agent memory, it is a formatted structure for storing messages and
plans, and it can be persisted to a database or a file.
"""

from .base import (  # noqa: F401
    GptsMessage,
    GptsMessageMemory,
    GptsPlan,
    GptsPlansMemory,
)
from .default_gpts_memory import (  # noqa: F401
    DefaultGptsMessageMemory,
    DefaultGptsPlansMemory,
)
from .gpts_memory import GptsMemory  # noqa: F401
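Note: a minimal sketch of storing and querying a conversation message with the default in-memory backend (names and content illustrative):

from dbgpt.agent.core.memory.gpts import DefaultGptsMessageMemory, GptsMessage

memory = DefaultGptsMessageMemory()
memory.append(
    GptsMessage(
        conv_id="conv_1",
        sender="Planner",
        receiver="Coder",
        role="assistant",
        content="Write a query for daily active users.",
        rounds=1,
    )
)
print(memory.get_by_conv_id("conv_1"))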
dbgpt/agent/core/memory/gpts/base.py (Normal file, 250 lines)
@@ -0,0 +1,250 @@
"""Base memory interface for agents."""

import dataclasses
from abc import ABC, abstractmethod
from datetime import datetime
from typing import Any, Dict, List, Optional

from ...schema import Status


@dataclasses.dataclass
class GptsPlan:
    """Gpts plan."""

    conv_id: str
    sub_task_num: int
    sub_task_content: Optional[str]
    sub_task_title: Optional[str] = None
    sub_task_agent: Optional[str] = None
    resource_name: Optional[str] = None
    rely: Optional[str] = None
    agent_model: Optional[str] = None
    retry_times: int = 0
    max_retry_times: int = 5
    state: Optional[str] = Status.TODO.value
    result: Optional[str] = None

    @staticmethod
    def from_dict(d: Dict[str, Any]) -> "GptsPlan":
        """Create a GptsPlan object from a dictionary."""
        return GptsPlan(
            conv_id=d["conv_id"],
            sub_task_num=d["sub_task_num"],
            sub_task_content=d["sub_task_content"],
            sub_task_agent=d["sub_task_agent"],
            resource_name=d["resource_name"],
            rely=d["rely"],
            agent_model=d["agent_model"],
            retry_times=d["retry_times"],
            max_retry_times=d["max_retry_times"],
            state=d["state"],
            result=d["result"],
        )

    def to_dict(self) -> Dict[str, Any]:
        """Return a dictionary representation of the GptsPlan object."""
        return dataclasses.asdict(self)


@dataclasses.dataclass
class GptsMessage:
    """Gpts message."""

    conv_id: str
    sender: str
    receiver: str
    role: str
    content: str
    rounds: Optional[int]
    current_goal: Optional[str] = None
    context: Optional[str] = None
    review_info: Optional[str] = None
    action_report: Optional[str] = None
    model_name: Optional[str] = None
    created_at: datetime = dataclasses.field(default_factory=datetime.utcnow)
    updated_at: datetime = dataclasses.field(default_factory=datetime.utcnow)

    @staticmethod
    def from_dict(d: Dict[str, Any]) -> "GptsMessage":
        """Create a GptsMessage object from a dictionary."""
        return GptsMessage(
            conv_id=d["conv_id"],
            sender=d["sender"],
            receiver=d["receiver"],
            role=d["role"],
            content=d["content"],
            rounds=d["rounds"],
            model_name=d["model_name"],
            current_goal=d["current_goal"],
            context=d["context"],
            review_info=d["review_info"],
            action_report=d["action_report"],
            created_at=d["created_at"],
            updated_at=d["updated_at"],
        )

    def to_dict(self) -> Dict[str, Any]:
        """Return a dictionary representation of the GptsMessage object."""
        return dataclasses.asdict(self)


class GptsPlansMemory(ABC):
    """Gpts plans memory interface."""

    @abstractmethod
    def batch_save(self, plans: List[GptsPlan]) -> None:
        """Save plans in batch.

        Args:
            plans(List[GptsPlan]): Plans generated by the planner
        """

    @abstractmethod
    def get_by_conv_id(self, conv_id: str) -> List[GptsPlan]:
        """Get plans by conv_id.

        Args:
            conv_id(str): Conversation id

        Returns:
            List[GptsPlan]: List of planning steps
        """

    @abstractmethod
    def get_by_conv_id_and_num(
        self, conv_id: str, task_nums: List[int]
    ) -> List[GptsPlan]:
        """Get plans by conv_id and task numbers.

        Args:
            conv_id(str): Conversation id
            task_nums(List[int]): List of sequence numbers of plans in the same
                conversation

        Returns:
            List[GptsPlan]: List of planning steps
        """

    @abstractmethod
    def get_todo_plans(self, conv_id: str) -> List[GptsPlan]:
        """Get unfinished planning steps.

        Args:
            conv_id(str): Conversation id

        Returns:
            List[GptsPlan]: List of planning steps
        """

    @abstractmethod
    def complete_task(self, conv_id: str, task_num: int, result: str) -> None:
        """Set the planning step to complete.

        Args:
            conv_id(str): Conversation id
            task_num(int): Planning step number
            result(str): Planning step result
        """

    @abstractmethod
    def update_task(
        self,
        conv_id: str,
        task_num: int,
        state: str,
        retry_times: int,
        agent: Optional[str] = None,
        model: Optional[str] = None,
        result: Optional[str] = None,
    ) -> None:
        """Update planning step information.

        Args:
            conv_id(str): Conversation id
            task_num(int): Planning step number
            state(str): The status to update to
            retry_times(int): Latest number of retries
            agent(str): Agent's name
            model(str): Model name
            result(str): Planning step result
        """

    @abstractmethod
    def remove_by_conv_id(self, conv_id: str) -> None:
        """Remove plans by conversation id.

        Args:
            conv_id(str): Conversation id
        """


class GptsMessageMemory(ABC):
    """Gpts message memory interface."""

    @abstractmethod
    def append(self, message: GptsMessage) -> None:
        """Add a message.

        Args:
            message(GptsMessage): Message object
        """

    @abstractmethod
    def get_by_agent(self, conv_id: str, agent: str) -> Optional[List[GptsMessage]]:
        """Return all messages of the agent in the conversation.

        Args:
            conv_id(str): Conversation id
            agent(str): Agent's name

        Returns:
            List[GptsMessage]: List of messages
        """

    @abstractmethod
    def get_between_agents(
        self,
        conv_id: str,
        agent1: str,
        agent2: str,
        current_goal: Optional[str] = None,
    ) -> List[GptsMessage]:
        """Get messages between two agents.

        Query information related to an agent.

        Args:
            conv_id(str): Conversation id
            agent1(str): Agent1's name
            agent2(str): Agent2's name
            current_goal(str): Current goal

        Returns:
            List[GptsMessage]: List of messages
        """

    @abstractmethod
    def get_by_conv_id(self, conv_id: str) -> List[GptsMessage]:
        """Return all messages in the conversation.

        Query messages by conversation id.

        Args:
            conv_id(str): Conversation id

        Returns:
            List[GptsMessage]: List of messages
        """

    @abstractmethod
    def get_last_message(self, conv_id: str) -> Optional[GptsMessage]:
        """Return the last message in the conversation.

        Args:
            conv_id(str): Conversation id

        Returns:
            GptsMessage: The last message in the conversation
        """
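Note: GptsPlan and GptsMessage round-trip through to_dict()/from_dict(), which is what lets the default pandas backend below rebuild objects from rows. A small sketch (GptsPlan.from_dict does not read sub_task_title, so the round-trip relies on its default value here):

from dbgpt.agent.core.memory.gpts import GptsPlan

plan = GptsPlan(conv_id="conv_1", sub_task_num=1, sub_task_content="Collect data")
restored = GptsPlan.from_dict(plan.to_dict())
assert restored == plan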
dbgpt/agent/core/memory/gpts/default_gpts_memory.py (Normal file, 149 lines)
@@ -0,0 +1,149 @@
"""Default memory for storing plans and messages."""

from dataclasses import fields
from typing import List, Optional

import pandas as pd

from ...schema import Status
from .base import GptsMessage, GptsMessageMemory, GptsPlan, GptsPlansMemory


class DefaultGptsPlansMemory(GptsPlansMemory):
    """Default memory for storing plans."""

    def __init__(self):
        """Create a memory to store plans."""
        self.df = pd.DataFrame(columns=[field.name for field in fields(GptsPlan)])

    def batch_save(self, plans: List[GptsPlan]):
        """Save plans in batch."""
        new_rows = pd.DataFrame([item.to_dict() for item in plans])
        self.df = pd.concat([self.df, new_rows], ignore_index=True)

    def get_by_conv_id(self, conv_id: str) -> List[GptsPlan]:
        """Get plans by conv_id."""
        result = self.df.query("conv_id==@conv_id")
        plans = []
        for row in result.itertuples(index=False, name=None):
            row_dict = dict(zip(self.df.columns, row))
            plans.append(GptsPlan.from_dict(row_dict))
        return plans

    def get_by_conv_id_and_num(
        self, conv_id: str, task_nums: List[int]
    ) -> List[GptsPlan]:
        """Get plans by conv_id and task numbers."""
        task_nums_int = [int(num) for num in task_nums]  # noqa: F841
        result = self.df.query(
            "conv_id==@conv_id and sub_task_num in @task_nums_int"
        )
        plans = []
        for row in result.itertuples(index=False, name=None):
            row_dict = dict(zip(self.df.columns, row))
            plans.append(GptsPlan.from_dict(row_dict))
        return plans

    def get_todo_plans(self, conv_id: str) -> List[GptsPlan]:
        """Get unfinished planning steps."""
        todo_states = [Status.TODO.value, Status.RETRYING.value]  # noqa: F841
        result = self.df.query("conv_id==@conv_id and state in @todo_states")
        plans = []
        for row in result.itertuples(index=False, name=None):
            row_dict = dict(zip(self.df.columns, row))
            plans.append(GptsPlan.from_dict(row_dict))
        return plans

    def complete_task(self, conv_id: str, task_num: int, result: str):
        """Set the planning step to complete."""
        condition = (self.df["conv_id"] == conv_id) & (
            self.df["sub_task_num"] == task_num
        )
        self.df.loc[condition, "state"] = Status.COMPLETE.value
        self.df.loc[condition, "result"] = result

    def update_task(
        self,
        conv_id: str,
        task_num: int,
        state: str,
        retry_times: int,
        agent: Optional[str] = None,
        model: Optional[str] = None,
        result: Optional[str] = None,
    ):
        """Update the state of the planning step."""
        condition = (self.df["conv_id"] == conv_id) & (
            self.df["sub_task_num"] == task_num
        )
        self.df.loc[condition, "state"] = state
        self.df.loc[condition, "retry_times"] = retry_times
        self.df.loc[condition, "result"] = result

        if agent:
            self.df.loc[condition, "sub_task_agent"] = agent

        if model:
            self.df.loc[condition, "agent_model"] = model

    def remove_by_conv_id(self, conv_id: str):
        """Remove all plans in the conversation."""
        self.df.drop(self.df[self.df["conv_id"] == conv_id].index, inplace=True)


class DefaultGptsMessageMemory(GptsMessageMemory):
    """Default memory for storing messages."""

    def __init__(self):
        """Create a memory to store messages."""
        self.df = pd.DataFrame(columns=[field.name for field in fields(GptsMessage)])

    def append(self, message: GptsMessage):
        """Append a message to the memory."""
        self.df.loc[len(self.df)] = message.to_dict()

    def get_by_agent(self, conv_id: str, agent: str) -> Optional[List[GptsMessage]]:
        """Get all messages sent or received by the agent in the conversation."""
        result = self.df.query(
            "conv_id==@conv_id and (sender==@agent or receiver==@agent)"
        )
        messages = []
        for row in result.itertuples(index=False, name=None):
            row_dict = dict(zip(self.df.columns, row))
            messages.append(GptsMessage.from_dict(row_dict))
        return messages

    def get_between_agents(
        self,
        conv_id: str,
        agent1: str,
        agent2: str,
        current_goal: Optional[str] = None,
    ) -> List[GptsMessage]:
        """Get all messages between two agents in the conversation."""
        if current_goal:
            result = self.df.query(
                "conv_id==@conv_id and ((sender==@agent1 and receiver==@agent2) "
                "or (sender==@agent2 and receiver==@agent1)) "
                "and current_goal==@current_goal"
            )
        else:
            result = self.df.query(
                "conv_id==@conv_id and ((sender==@agent1 and receiver==@agent2) "
                "or (sender==@agent2 and receiver==@agent1))"
            )
        messages = []
        for row in result.itertuples(index=False, name=None):
            row_dict = dict(zip(self.df.columns, row))
            messages.append(GptsMessage.from_dict(row_dict))
        return messages

    def get_by_conv_id(self, conv_id: str) -> List[GptsMessage]:
        """Get all messages in the conversation."""
        result = self.df.query("conv_id==@conv_id")
        messages = []
        for row in result.itertuples(index=False, name=None):
            row_dict = dict(zip(self.df.columns, row))
            messages.append(GptsMessage.from_dict(row_dict))
        return messages

    def get_last_message(self, conv_id: str) -> Optional[GptsMessage]:
        """Get the last message in the conversation."""
        return None
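Note: a small sketch of the pandas-backed plan store above (illustrative data):

from dbgpt.agent.core.memory.gpts import DefaultGptsPlansMemory, GptsPlan

plans = DefaultGptsPlansMemory()
plans.batch_save(
    [
        GptsPlan(conv_id="conv_1", sub_task_num=1, sub_task_content="Collect data"),
        GptsPlan(conv_id="conv_1", sub_task_num=2, sub_task_content="Clean data"),
    ]
)
plans.complete_task("conv_1", 1, "done")
print([p.sub_task_num for p in plans.get_todo_plans("conv_1")])  # [2]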
dbgpt/agent/core/memory/gpts/gpts_memory.py (Normal file, 190 lines)
@@ -0,0 +1,190 @@
"""GPTs memory."""

import json
from collections import OrderedDict, defaultdict
from typing import Dict, List, Optional

from dbgpt.vis.client import VisAgentMessages, VisAgentPlans, vis_client

from ...action.base import ActionOutput
from .base import GptsMessage, GptsMessageMemory, GptsPlansMemory
from .default_gpts_memory import DefaultGptsMessageMemory, DefaultGptsPlansMemory

NONE_GOAL_PREFIX: str = "none_goal_count_"


class GptsMemory:
    """GPTs memory."""

    def __init__(
        self,
        plans_memory: Optional[GptsPlansMemory] = None,
        message_memory: Optional[GptsMessageMemory] = None,
    ):
        """Create a memory to store plans and messages."""
        self._plans_memory: GptsPlansMemory = (
            plans_memory if plans_memory is not None else DefaultGptsPlansMemory()
        )
        self._message_memory: GptsMessageMemory = (
            message_memory if message_memory is not None else DefaultGptsMessageMemory()
        )

    @property
    def plans_memory(self) -> GptsPlansMemory:
        """Return the plans memory."""
        return self._plans_memory

    @property
    def message_memory(self) -> GptsMessageMemory:
        """Return the message memory."""
        return self._message_memory

    async def _message_group_vis_build(self, message_group):
        if not message_group:
            return ""
        num: int = 0
        last_goal = next(reversed(message_group))
        last_goal_messages = message_group[last_goal]

        last_goal_message = last_goal_messages[-1]
        vis_items = []

        plan_temps = []
        for key, value in message_group.items():
            num = num + 1
            if key.startswith(NONE_GOAL_PREFIX):
                vis_items.append(await self._messages_to_plan_vis(plan_temps))
                plan_temps = []
                num = 0
                vis_items.append(await self._messages_to_agents_vis(value))
            else:
                num += 1
                plan_temps.append(
                    {
                        "name": key,
                        "num": num,
                        "status": "complete",
                        "agent": value[0].receiver if value else "",
                        "markdown": await self._messages_to_agents_vis(value),
                    }
                )

        if len(plan_temps) > 0:
            vis_items.append(await self._messages_to_plan_vis(plan_temps))
        vis_items.append(await self._messages_to_agents_vis([last_goal_message]))
        return "\n".join(vis_items)

    async def _plan_vis_build(self, plan_group: Dict[str, list]):
        num: int = 0
        plan_items = []
        for key, value in plan_group.items():
            num = num + 1
            plan_items.append(
                {
                    "name": key,
                    "num": num,
                    "status": "complete",
                    "agent": value[0].receiver if value else "",
                    "markdown": await self._messages_to_agents_vis(value),
                }
            )
        return await self._messages_to_plan_vis(plan_items)

    async def one_chat_completions_v2(self, conv_id: str):
        """Generate a visualization of the conversation."""
        messages = self.message_memory.get_by_conv_id(conv_id=conv_id)
        temp_group: Dict[str, List[GptsMessage]] = OrderedDict()
        none_goal_count = 1
        count: int = 0
        for message in messages:
            count = count + 1
            if count == 1:
                continue
            current_goal = message.current_goal

            last_goal = next(reversed(temp_group)) if temp_group else None
            if last_goal:
                last_goal_messages = temp_group[last_goal]
                if current_goal:
                    if current_goal == last_goal:
                        last_goal_messages.append(message)
                    else:
                        temp_group[current_goal] = [message]
                else:
                    temp_group[f"{NONE_GOAL_PREFIX}{none_goal_count}"] = [message]
                    none_goal_count += 1
            else:
                if current_goal:
                    temp_group[current_goal] = [message]
                else:
                    temp_group[f"{NONE_GOAL_PREFIX}{none_goal_count}"] = [message]
                    none_goal_count += 1

        return await self._message_group_vis_build(temp_group)

    async def one_chat_completions(self, conv_id: str):
        """Generate a visualization of the conversation."""
        messages = self.message_memory.get_by_conv_id(conv_id=conv_id)
        temp_group: Dict[str, List[GptsMessage]] = defaultdict(list)
        temp_messages = []
        vis_items = []
        count: int = 0
        for message in messages:
            count = count + 1
            if count == 1:
                continue
            if not message.current_goal or len(message.current_goal) <= 0:
                if len(temp_group) > 0:
                    vis_items.append(await self._plan_vis_build(temp_group))
                    temp_group.clear()

                temp_messages.append(message)
            else:
                if len(temp_messages) > 0:
                    vis_items.append(await self._messages_to_agents_vis(temp_messages))
                    temp_messages.clear()

                last_goal = message.current_goal
                temp_group[last_goal].append(message)

        if len(temp_group) > 0:
            vis_items.append(await self._plan_vis_build(temp_group))
            temp_group.clear()
        if len(temp_messages) > 0:
            vis_items.append(await self._messages_to_agents_vis(temp_messages, True))
            temp_messages.clear()

        return "\n".join(vis_items)

    async def _messages_to_agents_vis(
        self, messages: List[GptsMessage], is_last_message: bool = False
    ):
        if messages is None or len(messages) <= 0:
            return ""
        messages_view = []
        for message in messages:
            action_report_str = message.action_report
            view_info = message.content
            if action_report_str and len(action_report_str) > 0:
                action_out = ActionOutput.from_dict(json.loads(action_report_str))
                if action_out is not None and (
                    action_out.is_exe_success or is_last_message
                ):
                    view = action_out.view
                    view_info = view if view else action_out.content

            messages_view.append(
                {
                    "sender": message.sender,
                    "receiver": message.receiver,
                    "model": message.model_name,
                    "markdown": view_info,
                }
            )
        vis_component = vis_client.get(VisAgentMessages.vis_tag())
        return await vis_component.display(content=messages_view)

    async def _messages_to_plan_vis(self, messages: List[Dict]):
        if messages is None or len(messages) <= 0:
            return ""
        return await vis_client.get(VisAgentPlans.vis_tag()).display(content=messages)
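Note: a hedged sketch of how GptsMemory is wired together; with no arguments it falls back to the in-memory pandas backends, and the async one_chat_completions* methods render a conversation through the vis components:

from dbgpt.agent.core.memory.gpts import GptsMemory

gpts_memory = GptsMemory()  # defaults to DefaultGptsPlansMemory/DefaultGptsMessageMemory
plans_store = gpts_memory.plans_memory  # GptsPlansMemory
message_store = gpts_memory.message_memory  # GptsMessageMemory
# vis_markdown = await gpts_memory.one_chat_completions("conv_1")  # inside async code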
dbgpt/agent/core/memory/hybrid.py (Normal file, 288 lines)
@@ -0,0 +1,288 @@
|
||||
"""Hybrid memory module.
|
||||
|
||||
This structure explicitly models the human short-term and long-term memories. The
|
||||
short-term memory temporarily buffers recent perceptions, while long-term memory
|
||||
consolidates important information over time.
|
||||
"""
|
||||
|
||||
import os.path
|
||||
from concurrent.futures import Executor, ThreadPoolExecutor
|
||||
from datetime import datetime
|
||||
from typing import TYPE_CHECKING, Generic, List, Optional, Tuple, Type
|
||||
|
||||
from dbgpt.core import Embeddings, LLMClient
|
||||
from dbgpt.util.annotations import immutable, mutable
|
||||
|
||||
from .base import (
|
||||
DiscardedMemoryFragments,
|
||||
ImportanceScorer,
|
||||
InsightExtractor,
|
||||
Memory,
|
||||
SensoryMemory,
|
||||
ShortTermMemory,
|
||||
T,
|
||||
WriteOperation,
|
||||
)
|
||||
from .long_term import LongTermMemory
|
||||
from .short_term import EnhancedShortTermMemory
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from dbgpt.storage.vector_store.connector import VectorStoreConnector
|
||||
|
||||
|
||||
class HybridMemory(Memory, Generic[T]):
|
||||
"""Hybrid memory for the agent."""
|
||||
|
||||
importance_weight: float = 0.9
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
now: datetime,
|
||||
sensory_memory: SensoryMemory[T],
|
||||
short_term_memory: ShortTermMemory[T],
|
||||
long_term_memory: LongTermMemory[T],
|
||||
default_insight_extractor: Optional[InsightExtractor] = None,
|
||||
default_importance_scorer: Optional[ImportanceScorer] = None,
|
||||
):
|
||||
"""Create a hybrid memory."""
|
||||
self.now = now
|
||||
self._sensory_memory = sensory_memory
|
||||
self._short_term_memory = short_term_memory
|
||||
self._long_term_memory = long_term_memory
|
||||
self._default_insight_extractor = default_insight_extractor
|
||||
self._default_importance_scorer = default_importance_scorer
|
||||
|
||||
def structure_clone(
|
||||
self: "HybridMemory[T]", now: Optional[datetime] = None
|
||||
) -> "HybridMemory[T]":
|
||||
"""Return a structure clone of the memory."""
|
||||
now = now or self.now
|
||||
m = HybridMemory(
|
||||
now=now,
|
||||
sensory_memory=self._sensory_memory.structure_clone(now),
|
||||
short_term_memory=self._short_term_memory.structure_clone(now),
|
||||
long_term_memory=self._long_term_memory.structure_clone(now),
|
||||
)
|
||||
m._copy_from(self)
|
||||
return m
|
||||
|
||||
@classmethod
|
||||
def from_chroma(
|
||||
cls,
|
||||
vstore_name: Optional[str] = "_chroma_agent_memory_",
|
||||
vstore_path: Optional[str] = None,
|
||||
embeddings: Optional[Embeddings] = None,
|
||||
executor: Optional[Executor] = None,
|
||||
now: Optional[datetime] = None,
|
||||
sensory_memory: Optional[SensoryMemory[T]] = None,
|
||||
short_term_memory: Optional[ShortTermMemory[T]] = None,
|
||||
long_term_memory: Optional[LongTermMemory[T]] = None,
|
||||
**kwargs
|
||||
):
|
||||
"""Create a hybrid memory from Chroma vector store."""
|
||||
from dbgpt.configs.model_config import DATA_DIR
|
||||
from dbgpt.storage.vector_store.chroma_store import ChromaVectorConfig
|
||||
from dbgpt.storage.vector_store.connector import VectorStoreConnector
|
||||
|
||||
if not embeddings:
|
||||
from dbgpt.rag.embedding import DefaultEmbeddingFactory
|
||||
|
||||
embeddings = DefaultEmbeddingFactory.openai()
|
||||
|
||||
vstore_path = vstore_path or os.path.join(DATA_DIR, "agent_memory")
|
||||
|
||||
vector_store_connector = VectorStoreConnector.from_default(
|
||||
vector_store_type="Chroma",
|
||||
embedding_fn=embeddings,
|
||||
vector_store_config=ChromaVectorConfig(
|
||||
name=vstore_name,
|
||||
persist_path=vstore_path,
|
||||
),
|
||||
)
|
||||
return cls.from_vstore(
|
||||
vector_store_connector=vector_store_connector,
|
||||
embeddings=embeddings,
|
||||
executor=executor,
|
||||
now=now,
|
||||
sensory_memory=sensory_memory,
|
||||
short_term_memory=short_term_memory,
|
||||
long_term_memory=long_term_memory,
|
||||
**kwargs
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def from_vstore(
|
||||
cls,
|
||||
vector_store_connector: "VectorStoreConnector",
|
||||
embeddings: Optional[Embeddings] = None,
|
||||
executor: Optional[Executor] = None,
|
||||
now: Optional[datetime] = None,
|
||||
sensory_memory: Optional[SensoryMemory[T]] = None,
|
||||
short_term_memory: Optional[ShortTermMemory[T]] = None,
|
||||
long_term_memory: Optional[LongTermMemory[T]] = None,
|
||||
**kwargs
|
||||
):
|
||||
"""Create a hybrid memory from vector store."""
|
||||
if not embeddings:
|
||||
embeddings = vector_store_connector.current_embeddings
|
||||
if not executor:
|
||||
executor = ThreadPoolExecutor()
|
||||
if not now:
|
||||
now = datetime.now()
|
||||
|
||||
if not sensory_memory:
|
||||
sensory_memory = SensoryMemory()
|
||||
if not short_term_memory:
|
||||
if not embeddings:
|
||||
raise ValueError("embeddings is required.")
|
||||
short_term_memory = EnhancedShortTermMemory(embeddings, executor)
|
||||
if not long_term_memory:
|
||||
long_term_memory = LongTermMemory(
|
||||
executor,
|
||||
vector_store_connector,
|
||||
now=now,
|
||||
)
|
||||
return cls(now, sensory_memory, short_term_memory, long_term_memory, **kwargs)
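    # Illustrative construction sketch (not part of the original change):
    #
    #     memory = HybridMemory.from_chroma(vstore_name="my_agent_memory")
    #
    # from_chroma fills in the remaining pieces with defaults: OpenAI
    # embeddings, a ThreadPoolExecutor, a SensoryMemory, an
    # EnhancedShortTermMemory, and a LongTermMemory persisted under
    # DATA_DIR/agent_memory.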

    def initialize(
        self,
        name: Optional[str] = None,
        llm_client: Optional[LLMClient] = None,
        importance_scorer: Optional[ImportanceScorer[T]] = None,
        insight_extractor: Optional[InsightExtractor[T]] = None,
        real_memory_fragment_class: Optional[Type[T]] = None,
    ) -> None:
        """Initialize the memory.

        It will initialize all the memories.
        """
        memories = [
            self._sensory_memory,
            self._short_term_memory,
            self._long_term_memory,
        ]
        kwargs = {
            "name": name,
            "llm_client": llm_client,
            "importance_scorer": importance_scorer,
            "insight_extractor": insight_extractor,
            "real_memory_fragment_class": real_memory_fragment_class,
        }
        for memory in memories:
            memory.initialize(**kwargs)
        super().initialize(**kwargs)

    @mutable
    async def write(
        self,
        memory_fragment: T,
        now: Optional[datetime] = None,
        op: WriteOperation = WriteOperation.ADD,
    ) -> Optional[DiscardedMemoryFragments[T]]:
        """Write a memory fragment to the memory."""
        # First write to sensory memory
        sen_discarded_memories = await self._sensory_memory.write(memory_fragment)
        if not sen_discarded_memories:
            return None
        short_term_discarded_memories = []
        discarded_memory_fragments = []
        discarded_insights = []
        for sen_memory in sen_discarded_memories.discarded_memory_fragments:
            # Write to short term memory
            short_discarded_memory = await self._short_term_memory.write(sen_memory)
            if short_discarded_memory:
                short_term_discarded_memories.append(short_discarded_memory)
                discarded_memory_fragments.extend(
                    short_discarded_memory.discarded_memory_fragments
                )
                for insight in short_discarded_memory.discarded_insights:
                    # Just keep the first insight
                    discarded_insights.append(insight.insights[0])
        # Score the importance of the discarded insights
        insight_scores = await self.score_memory_importance(discarded_insights)
        for i, ins in enumerate(discarded_insights):
            ins.update_importance(insight_scores[i])
        all_memories = discarded_memory_fragments + discarded_insights
        if self._long_term_memory:
            # Write to long term memory
            await self._long_term_memory.write_batch(all_memories, self.now)
        return None
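    # Illustrative write cascade (a sketch, not part of the original change):
    # a fragment enters sensory memory first; only fragments discarded by the
    # sensory buffer cascade into short-term memory, and only short-term
    # discards (plus their scored insights) are consolidated into long-term
    # storage, e.g.:
    #
    #     await memory.write(AgentMemoryFragment("User prefers Python."))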

    @immutable
    async def read(
        self,
        observation: str,
        alpha: Optional[float] = None,
        beta: Optional[float] = None,
        gamma: Optional[float] = None,
    ) -> List[T]:
        """Read memories from the memory."""
        (
            retrieved_long_term_memories,
            short_term_discarded_memories,
        ) = await self.fetch_memories(observation, self._short_term_memory)

        await self.save_memories_after_retrieval(short_term_discarded_memories)
        return retrieved_long_term_memories

    @immutable
    async def fetch_memories(
        self,
        observation: str,
        short_term_memory: Optional[ShortTermMemory[T]] = None,
    ) -> Tuple[List[T], List[DiscardedMemoryFragments[T]]]:
        """Fetch memories from long term memory.

        If short_term_memory is provided, write the fetched memories to the
        short term memory.
        """
        retrieved_long_term_memories = await self._long_term_memory.fetch_memories(
            observation
        )
        if not short_term_memory:
            return retrieved_long_term_memories, []
        short_term_discarded_memories: List[DiscardedMemoryFragments[T]] = []
        discarded_memory_fragments: List[T] = []
        for ltm in retrieved_long_term_memories:
            short_discarded_memory = await short_term_memory.write(
                ltm, op=WriteOperation.RETRIEVAL
            )
            if short_discarded_memory:
                short_term_discarded_memories.append(short_discarded_memory)
                discarded_memory_fragments.extend(
                    short_discarded_memory.discarded_memory_fragments
                )
        for stm in short_term_memory.short_term_memories:
            retrieved_long_term_memories.append(
                stm.current_class.build_from(
                    observation=stm.raw_observation,
                    importance=stm.importance,
                )
            )
        return retrieved_long_term_memories, short_term_discarded_memories

    async def save_memories_after_retrieval(
        self, fragments: List[DiscardedMemoryFragments[T]]
    ):
        """Save memories after retrieval."""
        discarded_memory_fragments = []
        discarded_memory_insights: List[T] = []
        for f in fragments:
            discarded_memory_fragments.extend(f.discarded_memory_fragments)
            for fi in f.discarded_insights:
                discarded_memory_insights.append(fi.insights[0])
        insights_importance = await self.score_memory_importance(
            discarded_memory_insights
        )
        for i, ins in enumerate(discarded_memory_insights):
            ins.update_importance(insights_importance[i])
        all_memories = discarded_memory_fragments + discarded_memory_insights
        await self._long_term_memory.write_batch(all_memories, self.now)

    async def clear(self) -> List[T]:
        """Clear the memory.

        TODO: Implement this method.
        """
        return []
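An illustrative read path (a sketch, not part of the original change): read() returns the long-term hits merged with whatever currently sits in the short-term buffer, and persists any fragments that the retrieval pushed out of that buffer:

    async def recall(memory: HybridMemory, query: str):
        memories = await memory.read(query)
        return [(m.raw_observation, m.importance) for m in memories]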
174
dbgpt/agent/core/memory/llm.py
Normal file
@@ -0,0 +1,174 @@
"""LLM Utility For Agent Memory."""
|
||||
|
||||
import re
|
||||
from typing import List, Optional, Union
|
||||
|
||||
from dbgpt._private.pydantic import BaseModel
|
||||
from dbgpt.core import (
|
||||
ChatPromptTemplate,
|
||||
HumanPromptTemplate,
|
||||
LLMClient,
|
||||
ModelMessage,
|
||||
ModelRequest,
|
||||
)
|
||||
|
||||
from .base import ImportanceScorer, InsightExtractor, InsightMemoryFragment, T
|
||||
|
||||
|
||||
class BaseLLMCaller(BaseModel):
|
||||
"""Base class for LLM caller."""
|
||||
|
||||
prompt: str = ""
|
||||
model: Optional[str] = None
|
||||
|
||||
async def call_llm(
|
||||
self,
|
||||
prompt: Union[ChatPromptTemplate, str],
|
||||
llm_client: Optional[LLMClient] = None,
|
||||
**kwargs,
|
||||
) -> str:
|
||||
"""Call LLM client to generate response.
|
||||
|
||||
Args:
|
||||
llm_client(LLMClient): LLM client
|
||||
prompt(ChatPromptTemplate): prompt
|
||||
**kwargs: other keyword arguments
|
||||
|
||||
Returns:
|
||||
str: response
|
||||
"""
|
||||
if not llm_client:
|
||||
raise ValueError("LLM client is required.")
|
||||
if isinstance(prompt, str):
|
||||
prompt = ChatPromptTemplate(
|
||||
messages=[HumanPromptTemplate.from_template(prompt)]
|
||||
)
|
||||
model = self.model
|
||||
if not model:
|
||||
model = await self.get_model(llm_client)
|
||||
prompt_kwargs = {}
|
||||
prompt_kwargs.update(kwargs)
|
||||
pass_kwargs = {
|
||||
k: v for k, v in prompt_kwargs.items() if k in prompt.input_variables
|
||||
}
|
||||
messages = prompt.format_messages(**pass_kwargs)
|
||||
model_messages = ModelMessage.from_base_messages(messages)
|
||||
model_request = ModelRequest.build_request(model, messages=model_messages)
|
||||
model_output = await llm_client.generate(model_request)
|
||||
if not model_output.success:
|
||||
raise ValueError("Call LLM failed.")
|
||||
return model_output.text
|
||||
|
||||
async def get_model(self, llm_client: LLMClient) -> str:
|
||||
"""Get the model.
|
||||
|
||||
Args:
|
||||
llm_client(LLMClient): LLM client
|
||||
|
||||
Returns:
|
||||
str: model
|
||||
"""
|
||||
models = await llm_client.models()
|
||||
if not models:
|
||||
raise ValueError("No models available.")
|
||||
self.model = models[0].model
|
||||
return self.model
|
||||
|
||||
@staticmethod
|
||||
def _parse_list(text: str) -> List[str]:
|
||||
"""Parse a newline-separated string into a list of strings.
|
||||
|
||||
1. First, split by newline
|
||||
2. Remove whitespace from each line
|
||||
"""
|
||||
lines = re.split(r"\n", text.strip())
|
||||
lines = [line for line in lines if line.strip()] # remove empty lines
|
||||
# Use regular expression to remove the numbers and dots at the beginning of
|
||||
# each line
|
||||
return [re.sub(r"^\s*\d+\.\s*", "", line).strip() for line in lines]
|
||||
|
||||
@staticmethod
|
||||
def _parse_number(text: str, importance_weight: Optional[float] = None) -> float:
|
||||
"""Parse a number from a string."""
|
||||
match = re.search(r"^\D*(\d+)", text)
|
||||
if match:
|
||||
score = float(match.group(1))
|
||||
if importance_weight is not None:
|
||||
score = (score / 10) * importance_weight
|
||||
return score
|
||||
else:
|
||||
return 0.0
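    # For example (illustrative): _parse_number("Rating: 8") returns 8.0,
    # while _parse_number("8", importance_weight=0.15) returns
    # (8 / 10) * 0.15 = 0.12.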


class LLMInsightExtractor(BaseLLMCaller, InsightExtractor[T]):
    """LLM Insight Extractor.

    Get high-level insights from memories.
    """

    prompt: str = (
        "There are some memories: {content}\nCan you infer from the "
        "above memories the high-level insight for this person's character? The insight"
        " needs to be significantly different from the content and structure of the "
        "original memories. Respond in one sentence.\n\n"
        "Results:"
    )

    async def extract_insights(
        self,
        memory_fragment: T,
        llm_client: Optional[LLMClient] = None,
    ) -> InsightMemoryFragment[T]:
        """Extract insights from memory fragments.

        Args:
            memory_fragment(T): Memory fragment
            llm_client(Optional[LLMClient]): LLM client

        Returns:
            InsightMemoryFragment: The insights of the memory fragment.
        """
        insights_str: str = await self.call_llm(
            self.prompt, llm_client, content=memory_fragment.raw_observation
        )
        insights_list = self._parse_list(insights_str)
        return InsightMemoryFragment(memory_fragment, insights_list)


class LLMImportanceScorer(BaseLLMCaller, ImportanceScorer[T]):
    """LLM Importance Scorer.

    Score the importance of memories.
    """

    prompt: str = (
        "Please give an importance score between 1 and 10 for the following "
        "observation. A higher score indicates that the observation is more "
        "important. More rules that should be followed are:"
        "\n(1): Learning experience of a certain skill is important"
        "\n(2): The occurrence of a particular event is important"
        "\n(3): User thoughts and emotions matter"
        "\n(4): More informative indicates more important."
        "Please respond with a single integer."
        "\nObservation:{content}"
        "\nRating:"
    )

    async def score_importance(
        self,
        memory_fragment: T,
        llm_client: Optional[LLMClient] = None,
    ) -> float:
        """Score the importance of memory fragments.

        Args:
            memory_fragment(T): Memory fragment
            llm_client(Optional[LLMClient]): LLM client

        Returns:
            float: The importance score of the memory fragment.
        """
        score: str = await self.call_llm(
            self.prompt, llm_client, content=memory_fragment.raw_observation
        )
        return self._parse_number(score)
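A quick usage sketch (illustrative, not part of the original change; `fragment` and `llm_client` stand for an AgentMemoryFragment and any configured LLMClient):

    async def score_and_reflect(fragment, llm_client):
        scorer = LLMImportanceScorer()
        extractor = LLMInsightExtractor()
        importance = await scorer.score_importance(fragment, llm_client)
        insights = await extractor.extract_insights(fragment, llm_client)
        return importance, insights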
192
dbgpt/agent/core/memory/long_term.py
Normal file
@@ -0,0 +1,192 @@
"""Long-term memory module."""
|
||||
|
||||
from concurrent.futures import Executor
|
||||
from datetime import datetime
|
||||
from typing import Generic, List, Optional
|
||||
|
||||
from dbgpt.core import Chunk
|
||||
from dbgpt.rag.retriever.time_weighted import TimeWeightedEmbeddingRetriever
|
||||
from dbgpt.storage.vector_store.connector import VectorStoreConnector
|
||||
from dbgpt.storage.vector_store.filters import MetadataFilters
|
||||
from dbgpt.util.annotations import immutable, mutable
|
||||
from dbgpt.util.executor_utils import blocking_func_to_async
|
||||
|
||||
from .base import DiscardedMemoryFragments, Memory, T, WriteOperation
|
||||
|
||||
_FORGET_PLACEHOLDER = "[FORGET]"
|
||||
_MERGE_PLACEHOLDER = "[MERGE]"
|
||||
_METADATA_BUFFER_IDX = "buffer_idx"
|
||||
_METADATA_LAST_ACCESSED_AT = "last_accessed_at"
|
||||
_METADAT_IMPORTANCE = "importance"
|
||||
|
||||
|
||||
class LongTermRetriever(TimeWeightedEmbeddingRetriever):
|
||||
"""Long-term retriever."""
|
||||
|
||||
def __init__(self, now: datetime, **kwargs):
|
||||
"""Create a long-term retriever."""
|
||||
self.now = now
|
||||
super().__init__(**kwargs)
|
||||
|
||||
@mutable
|
||||
def _retrieve(
|
||||
self, query: str, filters: Optional[MetadataFilters] = None
|
||||
) -> List[Chunk]:
|
||||
"""Retrieve memories."""
|
||||
current_time = self.now
|
||||
docs_and_scores = {
|
||||
doc.metadata[_METADATA_BUFFER_IDX]: (doc, self.default_salience)
|
||||
# Calculate for all memories.
|
||||
for doc in self.memory_stream
|
||||
}
|
||||
# If a doc is considered salient, update the salience score
|
||||
docs_and_scores.update(self.get_salient_docs(query))
|
||||
rescored_docs = [
|
||||
(doc, self._get_combined_score(doc, relevance, current_time))
|
||||
for doc, relevance in docs_and_scores.values()
|
||||
]
|
||||
rescored_docs.sort(key=lambda x: x[1], reverse=True)
|
||||
result = []
|
||||
# Ensure frequently accessed memories aren't forgotten
|
||||
retrieved_num = 0
|
||||
for doc, _ in rescored_docs:
|
||||
if (
|
||||
retrieved_num < self._k
|
||||
and doc.content.find(_FORGET_PLACEHOLDER) == -1
|
||||
and doc.content.find(_MERGE_PLACEHOLDER) == -1
|
||||
):
|
||||
retrieved_num += 1
|
||||
buffered_doc = self.memory_stream[doc.metadata[_METADATA_BUFFER_IDX]]
|
||||
buffered_doc.metadata[_METADATA_LAST_ACCESSED_AT] = current_time
|
||||
result.append(buffered_doc)
|
||||
return result
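    # Scoring sketch (inherited behavior, stated as an assumption): the
    # combined score mixes semantic relevance with a recency term that decays
    # with the time elapsed since _METADATA_LAST_ACCESSED_AT, which is why
    # retrieved docs have their last-accessed timestamp refreshed above.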


class LongTermMemory(Memory, Generic[T]):
    """Long-term memory."""

    importance_weight: float = 0.15

    def __init__(
        self,
        executor: Executor,
        vector_store_connector: VectorStoreConnector,
        now: Optional[datetime] = None,
        reflection_threshold: Optional[float] = None,
    ):
        """Create a long-term memory."""
        self.now = now or datetime.now()
        self.executor = executor
        self.reflecting: bool = False
        self.forgetting: bool = False
        self.reflection_threshold: Optional[float] = reflection_threshold
        self.aggregate_importance: float = 0.0
        self._vector_store_connector = vector_store_connector
        self.memory_retriever = LongTermRetriever(
            now=self.now, vector_store_connector=vector_store_connector
        )

    @immutable
    def structure_clone(
        self: "LongTermMemory[T]", now: Optional[datetime] = None
    ) -> "LongTermMemory[T]":
        """Create a structure clone of the long-term memory."""
        new_name = self.name
        if not new_name:
            raise ValueError("name is required.")
        m: LongTermMemory[T] = LongTermMemory(
            now=now,
            executor=self.executor,
            vector_store_connector=self._vector_store_connector.new_connector(new_name),
            reflection_threshold=self.reflection_threshold,
        )
        m._copy_from(self)
        return m

    @mutable
    async def write(
        self,
        memory_fragment: T,
        now: Optional[datetime] = None,
        op: WriteOperation = WriteOperation.ADD,
    ) -> Optional[DiscardedMemoryFragments[T]]:
        """Write a memory fragment to the memory."""
        importance = memory_fragment.importance
        last_accessed_time = memory_fragment.last_accessed_time
        if importance is None:
            raise ValueError("importance is required.")
        if not self.reflecting:
            self.aggregate_importance += importance

        memory_idx = len(self.memory_retriever.memory_stream)
        document = Chunk(
            page_content="[{}] ".format(memory_idx)
            + str(memory_fragment.raw_observation),
            metadata={
                _METADATA_IMPORTANCE: importance,
                _METADATA_LAST_ACCESSED_AT: last_accessed_time,
            },
        )
        await blocking_func_to_async(
            self.executor,
            self.memory_retriever.load_document,
            [document],
            current_time=now,
        )

        return None

    @mutable
    async def write_batch(
        self, memory_fragments: List[T], now: Optional[datetime] = None
    ) -> Optional[DiscardedMemoryFragments[T]]:
        """Write a batch of memory fragments to the memory."""
        current_datetime = self.now
        if not now:
            raise ValueError("Now time is required.")
        for short_term_memory in memory_fragments:
            short_term_memory.update_accessed_time(now=now)
            await self.write(short_term_memory, now=current_datetime)
        # TODO(fangyinc): Reflect on the memories and get high-level insights.
        # TODO(fangyinc): Forget memories that are not important.
        return None

    @immutable
    async def read(
        self,
        observation: str,
        alpha: Optional[float] = None,
        beta: Optional[float] = None,
        gamma: Optional[float] = None,
    ) -> List[T]:
        """Read memory fragments related to the observation."""
        return await self.fetch_memories(observation=observation, now=self.now)

    @immutable
    async def fetch_memories(
        self, observation: str, now: Optional[datetime] = None
    ) -> List[T]:
        """Fetch memories related to the observation."""
        # TODO: Mock now?
        retrieved_memories = []
        retrieved_list = await blocking_func_to_async(
            self.executor,
            self.memory_retriever.retrieve,
            observation,
        )
        for retrieved_chunk in retrieved_list:
            retrieved_memories.append(
                self.real_memory_fragment_class.build_from(
                    observation=retrieved_chunk.content,
                    importance=retrieved_chunk.metadata[_METADATA_IMPORTANCE],
                )
            )
        return retrieved_memories

    @mutable
    async def clear(self) -> List[T]:
        """Clear the memory.

        TODO: Implement this method.
        """
        return []
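An illustrative round trip (a sketch, not part of the original change; the connector comes from whatever vector store the agent configured):

    from concurrent.futures import ThreadPoolExecutor

    async def remember_and_recall(vector_store_connector, fragment):
        # Assumes initialize() has been called so that
        # real_memory_fragment_class is set.
        ltm = LongTermMemory(ThreadPoolExecutor(), vector_store_connector)
        await ltm.write_batch([fragment], datetime.now())
        return await ltm.fetch_memories("database migration")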
203
dbgpt/agent/core/memory/short_term.py
Normal file
@@ -0,0 +1,203 @@
"""Short term memory module."""
|
||||
|
||||
import random
|
||||
from concurrent.futures import Executor
|
||||
from datetime import datetime
|
||||
from typing import Dict, List, Optional, Tuple
|
||||
|
||||
from dbgpt.core import Embeddings
|
||||
from dbgpt.util.annotations import immutable, mutable
|
||||
from dbgpt.util.executor_utils import blocking_func_to_async
|
||||
from dbgpt.util.similarity_util import cosine_similarity, sigmoid_function
|
||||
|
||||
from .base import (
|
||||
DiscardedMemoryFragments,
|
||||
InsightMemoryFragment,
|
||||
ShortTermMemory,
|
||||
T,
|
||||
WriteOperation,
|
||||
)
|
||||
|
||||
|
||||
class EnhancedShortTermMemory(ShortTermMemory[T]):
|
||||
"""Enhanced short term memory."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
embeddings: Embeddings,
|
||||
executor: Executor,
|
||||
buffer_size: int = 2,
|
||||
enhance_similarity_threshold: float = 0.7,
|
||||
enhance_threshold: int = 3,
|
||||
):
|
||||
"""Initialize enhanced short term memory."""
|
||||
super().__init__(buffer_size=buffer_size)
|
||||
self._executor = executor
|
||||
self._embeddings = embeddings
|
||||
self.short_embeddings: List[List[float]] = []
|
||||
self.enhance_cnt: List[int] = [0 for _ in range(self._buffer_size)]
|
||||
self.enhance_memories: List[List[T]] = [[] for _ in range(self._buffer_size)]
|
||||
self.enhance_similarity_threshold = enhance_similarity_threshold
|
||||
self.enhance_threshold = enhance_threshold
|
||||
|
||||
@immutable
|
||||
def structure_clone(
|
||||
self: "EnhancedShortTermMemory[T]", now: Optional[datetime] = None
|
||||
) -> "EnhancedShortTermMemory[T]":
|
||||
"""Return a structure clone of the memory."""
|
||||
m: EnhancedShortTermMemory[T] = EnhancedShortTermMemory(
|
||||
embeddings=self._embeddings,
|
||||
executor=self._executor,
|
||||
buffer_size=self._buffer_size,
|
||||
enhance_similarity_threshold=self.enhance_similarity_threshold,
|
||||
enhance_threshold=self.enhance_threshold,
|
||||
)
|
||||
m._copy_from(self)
|
||||
return m
|
||||
|
||||
@mutable
|
||||
async def write(
|
||||
self,
|
||||
memory_fragment: T,
|
||||
now: Optional[datetime] = None,
|
||||
op: WriteOperation = WriteOperation.ADD,
|
||||
) -> Optional[DiscardedMemoryFragments[T]]:
|
||||
"""Write memory fragment to short term memory.
|
||||
|
||||
Reference: https://github.com/RUC-GSAI/YuLan-Rec/blob/main/agents/recagent_memory.py#L336 # noqa
|
||||
"""
|
||||
# Calculate current embeddings of current memory fragment
|
||||
memory_fragment_embeddings = await blocking_func_to_async(
|
||||
self._executor,
|
||||
memory_fragment.calculate_current_embeddings,
|
||||
self._embeddings.embed_documents,
|
||||
)
|
||||
memory_fragment.update_embeddings(memory_fragment_embeddings)
|
||||
for idx, memory_embedding in enumerate(self.short_embeddings):
|
||||
similarity = await blocking_func_to_async(
|
||||
self._executor,
|
||||
cosine_similarity,
|
||||
memory_embedding,
|
||||
memory_fragment_embeddings,
|
||||
)
|
||||
# Sigmoid probability, transform similarity to [0, 1]
|
||||
sigmoid_prob: float = await blocking_func_to_async(
|
||||
self._executor, sigmoid_function, similarity
|
||||
)
|
||||
if (
|
||||
sigmoid_prob >= self.enhance_similarity_threshold
|
||||
and random.random() < sigmoid_prob
|
||||
):
|
||||
self.enhance_cnt[idx] += 1
|
||||
self.enhance_memories[idx].append(memory_fragment)
|
||||
discard_memories = await self.transfer_to_long_term(memory_fragment)
|
||||
if op == WriteOperation.ADD:
|
||||
self._fragments.append(memory_fragment)
|
||||
self.short_embeddings.append(memory_fragment_embeddings)
|
||||
await self.handle_overflow(self._fragments)
|
||||
return discard_memories
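    # Illustrative (not part of the original change): with the defaults above,
    # a buffered memory is "enhanced" by an incoming fragment only when
    # sigmoid(similarity) >= 0.7 AND a Bernoulli draw with that probability
    # succeeds, so near-duplicate observations are linked stochastically
    # rather than deterministically.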

    @mutable
    async def transfer_to_long_term(
        self, memory_fragment: T
    ) -> Optional[DiscardedMemoryFragments[T]]:
        """Transfer memory fragment to long term memory."""
        transfer_flag = False
        existing_memory = [True for _ in range(len(self.short_term_memories))]

        enhance_memories: List[T] = []
        to_get_insight_memories: List[T] = []
        for idx, memory in enumerate(self.short_term_memories):
            # If the memory exceeds the enhancement threshold
            if (
                self.enhance_cnt[idx] >= self.enhance_threshold
                and existing_memory[idx] is True
            ):
                existing_memory[idx] = False
                transfer_flag = True
                # Merge this short-term memory with the memories that
                # enhanced it
                content = [memory]
                # Do not repeatedly add observation memory to the summary,
                # so use [:-1].
                for enhance_memory in self.enhance_memories[idx][:-1]:
                    content.append(enhance_memory)
                content.append(memory_fragment)
                # Merge the enhanced memories into a single memory
                merged_enhance_memory: T = memory.reduce(
                    content, importance=memory.importance
                )
                to_get_insight_memories.append(merged_enhance_memory)
                enhance_memories.append(merged_enhance_memory)
        # Get insights for every enhanced memory
        enhance_insights: List[InsightMemoryFragment] = await self.get_insights(
            to_get_insight_memories
        )

        if transfer_flag:
            # Re-construct the indexes of short-term memories after removing
            # the summarized memories
            new_memories: List[T] = []
            new_embeddings: List[List[float]] = []
            new_enhance_memories: List[List[T]] = [[] for _ in range(self._buffer_size)]
            new_enhance_cnt: List[int] = [0 for _ in range(self._buffer_size)]
            for idx, memory in enumerate(self.short_term_memories):
                if existing_memory[idx]:
                    # Carry over the memories that were not summarized
                    new_enhance_memories[len(new_memories)] = self.enhance_memories[idx]
                    new_enhance_cnt[len(new_memories)] = self.enhance_cnt[idx]
                    new_memories.append(memory)
                    new_embeddings.append(self.short_embeddings[idx])
            self._fragments = new_memories
            self.short_embeddings = new_embeddings
            self.enhance_memories = new_enhance_memories
            self.enhance_cnt = new_enhance_cnt
        return DiscardedMemoryFragments(enhance_memories, enhance_insights)

    @mutable
    async def handle_overflow(
        self, memory_fragments: List[T]
    ) -> Tuple[List[T], List[T]]:
        """Handle overflow of short term memory.

        Discard the least important memory fragment if the buffer size is
        exceeded.
        """
        if len(self.short_term_memories) > self._buffer_size:
            id2fragments: Dict[int, Dict] = {}
            for idx in range(len(self.short_term_memories) - 1):
                # Never discard the most recent fragment
                memory = self.short_term_memories[idx]
                id2fragments[idx] = {
                    "enhance_count": self.enhance_cnt[idx],
                    "importance": memory.importance,
                }
            # Sort by importance and enhance count; discard the least
            # important first
            sorted_ids = sorted(
                id2fragments.keys(),
                key=lambda x: (
                    id2fragments[x]["importance"],
                    id2fragments[x]["enhance_count"],
                ),
            )
            pop_id = sorted_ids[0]
            pop_raw_observation = self.short_term_memories[pop_id].raw_observation
            self.enhance_cnt.pop(pop_id)
            self.enhance_cnt.append(0)
            self.enhance_memories.pop(pop_id)
            self.enhance_memories.append([])

            discard_memory = self._fragments.pop(pop_id)
            self.short_embeddings.pop(pop_id)

            # Remove the discarded memory from the other short-term memories'
            # enhancement lists
            for idx in range(len(self.short_term_memories)):
                current_enhance_memories: List[T] = self.enhance_memories[idx]
                to_remove_idx = []
                for i, ehf in enumerate(current_enhance_memories):
                    if ehf.raw_observation == pop_raw_observation:
                        to_remove_idx.append(i)
                # Pop in reverse order so earlier removals do not shift the
                # remaining indexes
                for i in reversed(to_remove_idx):
                    current_enhance_memories.pop(i)
                self.enhance_cnt[idx] -= len(to_remove_idx)

            return memory_fragments, [discard_memory]
        return memory_fragments, []
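A final illustrative sketch of the short-term life cycle (names are hypothetical; the embeddings come from the agent's configuration, and AgentMemoryFragment is the default fragment class in this change):

    from concurrent.futures import ThreadPoolExecutor

    async def buffer_observation(embeddings, observation: str):
        stm = EnhancedShortTermMemory(embeddings, ThreadPoolExecutor(), buffer_size=2)
        # Once a buffered memory has been enhanced enhance_threshold times, it
        # is summarized and returned via DiscardedMemoryFragments so the caller
        # can consolidate it into long-term storage.
        return await stm.write(AgentMemoryFragment(observation))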