Add convenience methods to ConversationBufferMemory and ConversationB… (#8981)

Add convenience methods to `ConversationBufferMemory` and
`ConversationBufferWindowMemory` to get buffer either as messages or as
string.

Helps when `return_messages` is set to `True` but you want access to the
messages as a string, and vice versa.

@hwchase17

One use case: Using a `MultiPromptRouter` where `default_chain` is
`ConversationChain`, but destination chains are `LLMChains`. Injecting
chat memory into prompts for destination chains prints a stringified
`List[Messages]` in the prompt, which creates a lot of noise. These
convenience methods allow caller to choose either as needed.

---------

Co-authored-by: Bagatur <baskaryan@gmail.com>
This commit is contained in:
Neil Murphy 2023-08-10 15:45:30 -07:00 committed by GitHub
parent 6221eb5974
commit 105c787e5a
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 35 additions and 21 deletions

View File

@ -17,15 +17,22 @@ class ConversationBufferMemory(BaseChatMemory):
@property
def buffer(self) -> Any:
"""String buffer of memory."""
if self.return_messages:
return self.chat_memory.messages
else:
return self.buffer_as_messages if self.return_messages else self.buffer_as_str
@property
def buffer_as_str(self) -> str:
"""Exposes the buffer as a string in case return_messages is True."""
return get_buffer_string(
self.chat_memory.messages,
human_prefix=self.human_prefix,
ai_prefix=self.ai_prefix,
)
@property
def buffer_as_messages(self) -> List[Any]:
"""Exposes the buffer as a list of messages in case return_messages is False."""
return self.chat_memory.messages
@property
def memory_variables(self) -> List[str]:
"""Will always return list of memory variables.

View File

@ -1,4 +1,4 @@
from typing import Any, Dict, List
from typing import Any, Dict, List, Union
from langchain.memory.chat_memory import BaseChatMemory
from langchain.schema.messages import BaseMessage, get_buffer_string
@ -14,9 +14,24 @@ class ConversationBufferWindowMemory(BaseChatMemory):
"""Number of messages to store in buffer."""
@property
def buffer(self) -> List[BaseMessage]:
def buffer(self) -> Union[str, List[BaseMessage]]:
"""String buffer of memory."""
return self.chat_memory.messages
return self.buffer_as_messages if self.return_messages else self.buffer_as_str
@property
def buffer_as_str(self) -> str:
"""Exposes the buffer as a string in case return_messages is True."""
messages = self.chat_memory.messages[-self.k * 2 :] if self.k > 0 else []
return get_buffer_string(
messages,
human_prefix=self.human_prefix,
ai_prefix=self.ai_prefix,
)
@property
def buffer_as_messages(self) -> List[BaseMessage]:
"""Exposes the buffer as a list of messages in case return_messages is False."""
return self.chat_memory.messages[-self.k * 2 :] if self.k > 0 else []
@property
def memory_variables(self) -> List[str]:
@ -26,14 +41,6 @@ class ConversationBufferWindowMemory(BaseChatMemory):
"""
return [self.memory_key]
def load_memory_variables(self, inputs: Dict[str, Any]) -> Dict[str, str]:
def load_memory_variables(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
"""Return history buffer."""
buffer: Any = self.buffer[-self.k * 2 :] if self.k > 0 else []
if not self.return_messages:
buffer = get_buffer_string(
buffer,
human_prefix=self.human_prefix,
ai_prefix=self.ai_prefix,
)
return {self.memory_key: buffer}
return {self.memory_key: self.buffer}