diff --git a/langchain/memory/vectorstore.py b/langchain/memory/vectorstore.py
index d5c40f26b99..60e0ae7acc2 100644
--- a/langchain/memory/vectorstore.py
+++ b/langchain/memory/vectorstore.py
@@ -1,6 +1,6 @@
 """Class for a VectorStore-backed memory object."""
 
-from typing import Any, Dict, List, Optional, Union
+from typing import Any, Dict, List, Optional, Sequence, Union
 
 from pydantic import Field
 
@@ -25,6 +25,9 @@ class VectorStoreRetrieverMemory(BaseMemory):
     return_docs: bool = False
     """Whether or not to return the result of querying the database directly."""
 
+    exclude_input_keys: Sequence[str] = Field(default_factory=tuple)
+    """Input keys, besides the memory key, to leave out of the stored document."""
+
     @property
     def memory_variables(self) -> List[str]:
         """The list of keys emitted from the load_memory_variables method."""
@@ -55,10 +58,13 @@ class VectorStoreRetrieverMemory(BaseMemory):
     ) -> List[Document]:
         """Format context from this conversation to buffer."""
         # Each document should only include the current turn, not the chat history
-        filtered_inputs = {k: v for k, v in inputs.items() if k != self.memory_key}
+        exclude = set(self.exclude_input_keys)
+        exclude.add(self.memory_key)
+        filtered_inputs = {k: v for k, v in inputs.items() if k not in exclude}
+        # Only inputs are filtered; output keys are never dropped from the document.
         texts = [
             f"{k}: {v}"
             for k, v in list(filtered_inputs.items()) + list(outputs.items())
         ]
         page_content = "\n".join(texts)
         return [Document(page_content=page_content)]