From 105c787e5af36ded72fe831a462c5f449a16dc15 Mon Sep 17 00:00:00 2001 From: Neil Murphy Date: Thu, 10 Aug 2023 15:45:30 -0700 Subject: [PATCH] =?UTF-8?q?Add=20convenience=20methods=20to=20Conversation?= =?UTF-8?q?BufferMemory=20and=20ConversationB=E2=80=A6=20(#8981)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Add convenience methods to `ConversationBufferMemory` and `ConversationBufferWindowMemory` to get buffer either as messages or as string. Helps when `return_messages` is set to `True` but you want access to the messages as a string, and vice versa. @hwchase17 One use case: Using a `MultiPromptRouter` where `default_chain` is `ConversationChain`, but destination chains are `LLMChains`. Injecting chat memory into prompts for destination chains prints a stringified `List[Messages]` in the prompt, which creates a lot of noise. These convenience methods allow caller to choose either as needed. --------- Co-authored-by: Bagatur --- libs/langchain/langchain/memory/buffer.py | 23 ++++++++----- .../langchain/memory/buffer_window.py | 33 +++++++++++-------- 2 files changed, 35 insertions(+), 21 deletions(-) diff --git a/libs/langchain/langchain/memory/buffer.py b/libs/langchain/langchain/memory/buffer.py index 50b1468b648..7eac112b45b 100644 --- a/libs/langchain/langchain/memory/buffer.py +++ b/libs/langchain/langchain/memory/buffer.py @@ -17,14 +17,21 @@ class ConversationBufferMemory(BaseChatMemory): @property def buffer(self) -> Any: """String buffer of memory.""" - if self.return_messages: - return self.chat_memory.messages - else: - return get_buffer_string( - self.chat_memory.messages, - human_prefix=self.human_prefix, - ai_prefix=self.ai_prefix, - ) + return self.buffer_as_messages if self.return_messages else self.buffer_as_str + + @property + def buffer_as_str(self) -> str: + """Exposes the buffer as a string in case return_messages is True.""" + return get_buffer_string( + self.chat_memory.messages, + human_prefix=self.human_prefix, + ai_prefix=self.ai_prefix, + ) + + @property + def buffer_as_messages(self) -> List[Any]: + """Exposes the buffer as a list of messages in case return_messages is False.""" + return self.chat_memory.messages @property def memory_variables(self) -> List[str]: diff --git a/libs/langchain/langchain/memory/buffer_window.py b/libs/langchain/langchain/memory/buffer_window.py index af27f41d33e..05b883e6d7d 100644 --- a/libs/langchain/langchain/memory/buffer_window.py +++ b/libs/langchain/langchain/memory/buffer_window.py @@ -1,4 +1,4 @@ -from typing import Any, Dict, List +from typing import Any, Dict, List, Union from langchain.memory.chat_memory import BaseChatMemory from langchain.schema.messages import BaseMessage, get_buffer_string @@ -14,9 +14,24 @@ class ConversationBufferWindowMemory(BaseChatMemory): """Number of messages to store in buffer.""" @property - def buffer(self) -> List[BaseMessage]: + def buffer(self) -> Union[str, List[BaseMessage]]: """String buffer of memory.""" - return self.chat_memory.messages + return self.buffer_as_messages if self.return_messages else self.buffer_as_str + + @property + def buffer_as_str(self) -> str: + """Exposes the buffer as a string in case return_messages is True.""" + messages = self.chat_memory.messages[-self.k * 2 :] if self.k > 0 else [] + return get_buffer_string( + messages, + human_prefix=self.human_prefix, + ai_prefix=self.ai_prefix, + ) + + @property + def buffer_as_messages(self) -> List[BaseMessage]: + """Exposes the buffer as a list of messages in case return_messages is False.""" + return self.chat_memory.messages[-self.k * 2 :] if self.k > 0 else [] @property def memory_variables(self) -> List[str]: @@ -26,14 +41,6 @@ class ConversationBufferWindowMemory(BaseChatMemory): """ return [self.memory_key] - def load_memory_variables(self, inputs: Dict[str, Any]) -> Dict[str, str]: + def load_memory_variables(self, inputs: Dict[str, Any]) -> Dict[str, Any]: """Return history buffer.""" - - buffer: Any = self.buffer[-self.k * 2 :] if self.k > 0 else [] - if not self.return_messages: - buffer = get_buffer_string( - buffer, - human_prefix=self.human_prefix, - ai_prefix=self.ai_prefix, - ) - return {self.memory_key: buffer} + return {self.memory_key: self.buffer}