mirror of
https://github.com/hwchase17/langchain.git
synced 2025-07-04 20:28:10 +00:00
Add a conversation memory that combines a (optionally persistent) vectorstore history with a token buffer (#22155)
**langchain: ConversationVectorStoreTokenBufferMemory** -**Description:** This PR adds ConversationVectorStoreTokenBufferMemory. It is similar in concept to ConversationSummaryBufferMemory. It maintains an in-memory buffer of messages up to a preset token limit. After the limit is hit timestamped messages are written into a vectorstore retriever rather than into a summary. The user's prompt is then used to retrieve relevant fragments of the previous conversation. By persisting the vectorstore, one can maintain memory from session to session. -**Issue:** n/a -**Dependencies:** none -**Twitter handle:** Please no!!! - [X] **Add tests and docs**: I looked to see how the unit tests were written for the other ConversationMemory modules, but couldn't find anything other than a test for successful import. I need to know whether you are using pytest.mock or another fixture to simulate the LLM and vectorstore. In addition, I would like guidance on where to place the documentation. Should it be a notebook file in docs/docs? - [X] **Lint and test**: I am seeing some linting errors from a couple of modules unrelated to this PR. If no one reviews your PR within a few days, please @-mention one of baskaryan, efriis, eyurtsev, ccurme, vbarda, hwchase17. --------- Co-authored-by: Lincoln Stein <lstein@gmail.com> Co-authored-by: isaac hershenson <ihershenson@hmc.edu>
This commit is contained in:
parent
32f8f39974
commit
c314222796
@ -1,6 +1,7 @@
|
|||||||
"""__ModuleName__ document loader."""
|
"""__ModuleName__ document loader."""
|
||||||
|
|
||||||
from typing import Iterator
|
from typing import Iterator
|
||||||
|
|
||||||
from langchain_core.document_loaders.base import BaseLoader
|
from langchain_core.document_loaders.base import BaseLoader
|
||||||
from langchain_core.documents import Document
|
from langchain_core.documents import Document
|
||||||
|
|
||||||
|
@ -48,6 +48,9 @@ from langchain.memory.summary import ConversationSummaryMemory
|
|||||||
from langchain.memory.summary_buffer import ConversationSummaryBufferMemory
|
from langchain.memory.summary_buffer import ConversationSummaryBufferMemory
|
||||||
from langchain.memory.token_buffer import ConversationTokenBufferMemory
|
from langchain.memory.token_buffer import ConversationTokenBufferMemory
|
||||||
from langchain.memory.vectorstore import VectorStoreRetrieverMemory
|
from langchain.memory.vectorstore import VectorStoreRetrieverMemory
|
||||||
|
from langchain.memory.vectorstore_token_buffer_memory import (
|
||||||
|
ConversationVectorStoreTokenBufferMemory, # avoid circular import
|
||||||
|
)
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from langchain_community.chat_message_histories import (
|
from langchain_community.chat_message_histories import (
|
||||||
@ -122,6 +125,7 @@ __all__ = [
|
|||||||
"ConversationSummaryBufferMemory",
|
"ConversationSummaryBufferMemory",
|
||||||
"ConversationSummaryMemory",
|
"ConversationSummaryMemory",
|
||||||
"ConversationTokenBufferMemory",
|
"ConversationTokenBufferMemory",
|
||||||
|
"ConversationVectorStoreTokenBufferMemory",
|
||||||
"CosmosDBChatMessageHistory",
|
"CosmosDBChatMessageHistory",
|
||||||
"DynamoDBChatMessageHistory",
|
"DynamoDBChatMessageHistory",
|
||||||
"ElasticsearchChatMessageHistory",
|
"ElasticsearchChatMessageHistory",
|
||||||
|
@ -0,0 +1,184 @@
|
|||||||
|
"""
|
||||||
|
Class for a conversation memory buffer with older messages stored in a vectorstore .
|
||||||
|
|
||||||
|
This implementats a conversation memory in which the messages are stored in a memory
|
||||||
|
buffer up to a specified token limit. When the limit is exceeded, older messages are
|
||||||
|
saved to a vectorstore backing database. The vectorstore can be made persistent across
|
||||||
|
sessions.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import warnings
|
||||||
|
from datetime import datetime
|
||||||
|
from typing import Any, Dict, List
|
||||||
|
|
||||||
|
from langchain_core.messages import BaseMessage
|
||||||
|
from langchain_core.prompts.chat import SystemMessagePromptTemplate
|
||||||
|
from langchain_core.pydantic_v1 import Field, PrivateAttr
|
||||||
|
from langchain_core.vectorstores import VectorStoreRetriever
|
||||||
|
|
||||||
|
from langchain.memory import ConversationTokenBufferMemory, VectorStoreRetrieverMemory
|
||||||
|
from langchain.memory.chat_memory import BaseChatMemory
|
||||||
|
from langchain.text_splitter import RecursiveCharacterTextSplitter
|
||||||
|
|
||||||
|
DEFAULT_HISTORY_TEMPLATE = """
|
||||||
|
Current date and time: {current_time}.
|
||||||
|
|
||||||
|
Potentially relevant timestamped excerpts of previous conversations (you
|
||||||
|
do not need to use these if irrelevant):
|
||||||
|
{previous_history}
|
||||||
|
|
||||||
|
"""
|
||||||
|
|
||||||
|
TIMESTAMP_FORMAT = "%Y-%m-%d %H:%M:%S %Z"
|
||||||
|
|
||||||
|
|
||||||
|
class ConversationVectorStoreTokenBufferMemory(ConversationTokenBufferMemory):
|
||||||
|
"""Conversation chat memory with token limit and vectordb backing.
|
||||||
|
|
||||||
|
load_memory_variables() will return a dict with the key "history".
|
||||||
|
It contains background information retrieved from the vector store
|
||||||
|
plus recent lines of the current conversation.
|
||||||
|
|
||||||
|
To help the LLM understand the part of the conversation stored in the
|
||||||
|
vectorstore, each interaction is timestamped and the current date and
|
||||||
|
time is also provided in the history. A side effect of this is that the
|
||||||
|
LLM will have access to the current date and time.
|
||||||
|
|
||||||
|
Initialization arguments:
|
||||||
|
|
||||||
|
This class accepts all the initialization arguments of
|
||||||
|
ConversationTokenBufferMemory, such as `llm`. In addition, it
|
||||||
|
accepts the following additional arguments
|
||||||
|
|
||||||
|
retriever: (required) A VectorStoreRetriever object to use
|
||||||
|
as the vector backing store
|
||||||
|
|
||||||
|
split_chunk_size: (optional, 1000) Token chunk split size
|
||||||
|
for long messages generated by the AI
|
||||||
|
|
||||||
|
previous_history_template: (optional) Template used to format
|
||||||
|
the contents of the prompt history
|
||||||
|
|
||||||
|
|
||||||
|
Example using ChromaDB:
|
||||||
|
|
||||||
|
.. code-block:: python
|
||||||
|
|
||||||
|
from langchain.memory.token_buffer_vectorstore_memory import (
|
||||||
|
ConversationVectorStoreTokenBufferMemory
|
||||||
|
)
|
||||||
|
from langchain_community.vectorstores import Chroma
|
||||||
|
from langchain_community.embeddings import HuggingFaceInstructEmbeddings
|
||||||
|
from langchain_openai import OpenAI
|
||||||
|
|
||||||
|
embedder = HuggingFaceInstructEmbeddings(
|
||||||
|
query_instruction="Represent the query for retrieval: "
|
||||||
|
)
|
||||||
|
chroma = Chroma(collection_name="demo",
|
||||||
|
embedding_function=embedder,
|
||||||
|
collection_metadata={"hnsw:space": "cosine"},
|
||||||
|
)
|
||||||
|
|
||||||
|
retriever = chroma.as_retriever(
|
||||||
|
search_type="similarity_score_threshold",
|
||||||
|
search_kwargs={
|
||||||
|
'k': 5,
|
||||||
|
'score_threshold': 0.75,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
conversation_memory = ConversationVectorStoreTokenBufferMemory(
|
||||||
|
return_messages=True,
|
||||||
|
llm=OpenAI(),
|
||||||
|
retriever=retriever,
|
||||||
|
max_token_limit = 1000,
|
||||||
|
)
|
||||||
|
|
||||||
|
conversation_memory.save_context({"Human": "Hi there"},
|
||||||
|
{"AI": "Nice to meet you!"}
|
||||||
|
)
|
||||||
|
conversation_memory.save_context({"Human": "Nice day isn't it?"},
|
||||||
|
{"AI": "I love Wednesdays."}
|
||||||
|
)
|
||||||
|
conversation_memory.load_memory_variables({"input": "What time is it?"})
|
||||||
|
|
||||||
|
"""
|
||||||
|
|
||||||
|
retriever: VectorStoreRetriever = Field(exclude=True)
|
||||||
|
memory_key: str = "history"
|
||||||
|
previous_history_template: str = DEFAULT_HISTORY_TEMPLATE
|
||||||
|
split_chunk_size: int = 1000
|
||||||
|
|
||||||
|
_memory_retriever: VectorStoreRetrieverMemory = PrivateAttr(default=None)
|
||||||
|
_timestamps: List[datetime] = PrivateAttr(default_factory=list)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def memory_retriever(self) -> VectorStoreRetrieverMemory:
|
||||||
|
"""Return a memory retriever from the passed retriever object."""
|
||||||
|
if self._memory_retriever is not None:
|
||||||
|
return self._memory_retriever
|
||||||
|
self._memory_retriever = VectorStoreRetrieverMemory(retriever=self.retriever)
|
||||||
|
return self._memory_retriever
|
||||||
|
|
||||||
|
def load_memory_variables(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
|
||||||
|
"""Return history and memory buffer."""
|
||||||
|
try:
|
||||||
|
with warnings.catch_warnings():
|
||||||
|
warnings.simplefilter("ignore")
|
||||||
|
memory_variables = self.memory_retriever.load_memory_variables(inputs)
|
||||||
|
previous_history = memory_variables[self.memory_retriever.memory_key]
|
||||||
|
except AssertionError: # happens when db is empty
|
||||||
|
previous_history = ""
|
||||||
|
current_history = super().load_memory_variables(inputs)
|
||||||
|
template = SystemMessagePromptTemplate.from_template(
|
||||||
|
self.previous_history_template
|
||||||
|
)
|
||||||
|
messages = [
|
||||||
|
template.format(
|
||||||
|
previous_history=previous_history,
|
||||||
|
current_time=datetime.now().astimezone().strftime(TIMESTAMP_FORMAT),
|
||||||
|
)
|
||||||
|
]
|
||||||
|
messages.extend(current_history[self.memory_key])
|
||||||
|
return {self.memory_key: messages}
|
||||||
|
|
||||||
|
def save_context(self, inputs: Dict[str, Any], outputs: Dict[str, str]) -> None:
|
||||||
|
"""Save context from this conversation to buffer. Pruned."""
|
||||||
|
BaseChatMemory.save_context(self, inputs, outputs)
|
||||||
|
self._timestamps.append(datetime.now().astimezone())
|
||||||
|
# Prune buffer if it exceeds max token limit
|
||||||
|
buffer = self.chat_memory.messages
|
||||||
|
curr_buffer_length = self.llm.get_num_tokens_from_messages(buffer)
|
||||||
|
if curr_buffer_length > self.max_token_limit:
|
||||||
|
while curr_buffer_length > self.max_token_limit:
|
||||||
|
self._pop_and_store_interaction(buffer)
|
||||||
|
curr_buffer_length = self.llm.get_num_tokens_from_messages(buffer)
|
||||||
|
|
||||||
|
def save_remainder(self) -> None:
|
||||||
|
"""
|
||||||
|
Save the remainder of the conversation buffer to the vector store.
|
||||||
|
|
||||||
|
This is useful if you have made the vectorstore persistent, in which
|
||||||
|
case this can be called before the end of the session to store the
|
||||||
|
remainder of the conversation.
|
||||||
|
"""
|
||||||
|
buffer = self.chat_memory.messages
|
||||||
|
while len(buffer) > 0:
|
||||||
|
self._pop_and_store_interaction(buffer)
|
||||||
|
|
||||||
|
def _pop_and_store_interaction(self, buffer: List[BaseMessage]) -> None:
|
||||||
|
input = buffer.pop(0)
|
||||||
|
output = buffer.pop(0)
|
||||||
|
timestamp = self._timestamps.pop(0).strftime(TIMESTAMP_FORMAT)
|
||||||
|
# Split AI output into smaller chunks to avoid creating documents
|
||||||
|
# that will overflow the context window
|
||||||
|
ai_chunks = self._split_long_ai_text(str(output.content))
|
||||||
|
for index, chunk in enumerate(ai_chunks):
|
||||||
|
self.memory_retriever.save_context(
|
||||||
|
{"Human": f"<{timestamp}/00> {str(input.content)}"},
|
||||||
|
{"AI": f"<{timestamp}/{index:02}> {chunk}"},
|
||||||
|
)
|
||||||
|
|
||||||
|
def _split_long_ai_text(self, text: str) -> List[str]:
|
||||||
|
splitter = RecursiveCharacterTextSplitter(chunk_size=self.split_chunk_size)
|
||||||
|
return [chunk.page_content for chunk in splitter.create_documents([text])]
|
@ -13,6 +13,7 @@ EXPECTED_ALL = [
|
|||||||
"ConversationSummaryBufferMemory",
|
"ConversationSummaryBufferMemory",
|
||||||
"ConversationSummaryMemory",
|
"ConversationSummaryMemory",
|
||||||
"ConversationTokenBufferMemory",
|
"ConversationTokenBufferMemory",
|
||||||
|
"ConversationVectorStoreTokenBufferMemory",
|
||||||
"CosmosDBChatMessageHistory",
|
"CosmosDBChatMessageHistory",
|
||||||
"DynamoDBChatMessageHistory",
|
"DynamoDBChatMessageHistory",
|
||||||
"ElasticsearchChatMessageHistory",
|
"ElasticsearchChatMessageHistory",
|
||||||
|
Loading…
Reference in New Issue
Block a user