mirror of
https://github.com/hwchase17/langchain.git
synced 2025-06-26 16:43:35 +00:00
Add convenience methods to ConversationBufferMemory and ConversationB… (#8981)
Add convenience methods to `ConversationBufferMemory` and `ConversationBufferWindowMemory` to get buffer either as messages or as string. Helps when `return_messages` is set to `True` but you want access to the messages as a string, and vice versa. @hwchase17 One use case: Using a `MultiPromptRouter` where `default_chain` is `ConversationChain`, but destination chains are `LLMChains`. Injecting chat memory into prompts for destination chains prints a stringified `List[Messages]` in the prompt, which creates a lot of noise. These convenience methods allow caller to choose either as needed. --------- Co-authored-by: Bagatur <baskaryan@gmail.com>
This commit is contained in:
parent
6221eb5974
commit
105c787e5a
@ -17,14 +17,21 @@ class ConversationBufferMemory(BaseChatMemory):
|
||||
@property
|
||||
def buffer(self) -> Any:
|
||||
"""String buffer of memory."""
|
||||
if self.return_messages:
|
||||
return self.chat_memory.messages
|
||||
else:
|
||||
return get_buffer_string(
|
||||
self.chat_memory.messages,
|
||||
human_prefix=self.human_prefix,
|
||||
ai_prefix=self.ai_prefix,
|
||||
)
|
||||
return self.buffer_as_messages if self.return_messages else self.buffer_as_str
|
||||
|
||||
@property
|
||||
def buffer_as_str(self) -> str:
|
||||
"""Exposes the buffer as a string in case return_messages is True."""
|
||||
return get_buffer_string(
|
||||
self.chat_memory.messages,
|
||||
human_prefix=self.human_prefix,
|
||||
ai_prefix=self.ai_prefix,
|
||||
)
|
||||
|
||||
@property
|
||||
def buffer_as_messages(self) -> List[Any]:
|
||||
"""Exposes the buffer as a list of messages in case return_messages is False."""
|
||||
return self.chat_memory.messages
|
||||
|
||||
@property
|
||||
def memory_variables(self) -> List[str]:
|
||||
|
@ -1,4 +1,4 @@
|
||||
from typing import Any, Dict, List
|
||||
from typing import Any, Dict, List, Union
|
||||
|
||||
from langchain.memory.chat_memory import BaseChatMemory
|
||||
from langchain.schema.messages import BaseMessage, get_buffer_string
|
||||
@ -14,9 +14,24 @@ class ConversationBufferWindowMemory(BaseChatMemory):
|
||||
"""Number of messages to store in buffer."""
|
||||
|
||||
@property
|
||||
def buffer(self) -> List[BaseMessage]:
|
||||
def buffer(self) -> Union[str, List[BaseMessage]]:
|
||||
"""String buffer of memory."""
|
||||
return self.chat_memory.messages
|
||||
return self.buffer_as_messages if self.return_messages else self.buffer_as_str
|
||||
|
||||
@property
|
||||
def buffer_as_str(self) -> str:
|
||||
"""Exposes the buffer as a string in case return_messages is True."""
|
||||
messages = self.chat_memory.messages[-self.k * 2 :] if self.k > 0 else []
|
||||
return get_buffer_string(
|
||||
messages,
|
||||
human_prefix=self.human_prefix,
|
||||
ai_prefix=self.ai_prefix,
|
||||
)
|
||||
|
||||
@property
|
||||
def buffer_as_messages(self) -> List[BaseMessage]:
|
||||
"""Exposes the buffer as a list of messages in case return_messages is False."""
|
||||
return self.chat_memory.messages[-self.k * 2 :] if self.k > 0 else []
|
||||
|
||||
@property
|
||||
def memory_variables(self) -> List[str]:
|
||||
@ -26,14 +41,6 @@ class ConversationBufferWindowMemory(BaseChatMemory):
|
||||
"""
|
||||
return [self.memory_key]
|
||||
|
||||
def load_memory_variables(self, inputs: Dict[str, Any]) -> Dict[str, str]:
|
||||
def load_memory_variables(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""Return history buffer."""
|
||||
|
||||
buffer: Any = self.buffer[-self.k * 2 :] if self.k > 0 else []
|
||||
if not self.return_messages:
|
||||
buffer = get_buffer_string(
|
||||
buffer,
|
||||
human_prefix=self.human_prefix,
|
||||
ai_prefix=self.ai_prefix,
|
||||
)
|
||||
return {self.memory_key: buffer}
|
||||
return {self.memory_key: self.buffer}
|
||||
|
Loading…
Reference in New Issue
Block a user