From cd3f8582cbd44bec144c438462855991c89f4297 Mon Sep 17 00:00:00 2001 From: Harrison Chase Date: Mon, 1 May 2023 20:55:56 -0700 Subject: [PATCH] Harrison/combined memory (#3935) Co-authored-by: engkheng <60956360+outday29@users.noreply.github.com> --- langchain/memory/combined.py | 20 +++++++++- .../unit_tests/memory/test_combined_memory.py | 37 +++++++++++++++++++ 2 files changed, 56 insertions(+), 1 deletion(-) create mode 100644 tests/unit_tests/memory/test_combined_memory.py diff --git a/langchain/memory/combined.py b/langchain/memory/combined.py index 7969ca4689e..5d6574bcef0 100644 --- a/langchain/memory/combined.py +++ b/langchain/memory/combined.py @@ -1,4 +1,6 @@ -from typing import Any, Dict, List +from typing import Any, Dict, List, Set + +from pydantic import validator from langchain.schema import BaseMemory @@ -9,6 +11,22 @@ class CombinedMemory(BaseMemory): memories: List[BaseMemory] """For tracking all the memories that should be accessed.""" + @validator("memories") + def check_repeated_memory_variable( + cls, value: List[BaseMemory] + ) -> List[BaseMemory]: + all_variables: Set[str] = set() + for val in value: + overlap = all_variables.intersection(val.memory_variables) + if overlap: + raise ValueError( + f"The same variables {overlap} are found in multiple" + "memory object, which is not allowed by CombinedMemory." + ) + all_variables |= set(val.memory_variables) + + return value + @property def memory_variables(self) -> List[str]: """All the memory variables that this instance provides.""" diff --git a/tests/unit_tests/memory/test_combined_memory.py b/tests/unit_tests/memory/test_combined_memory.py new file mode 100644 index 00000000000..dcaf240183e --- /dev/null +++ b/tests/unit_tests/memory/test_combined_memory.py @@ -0,0 +1,37 @@ +"""Test for CombinedMemory class""" +# from langchain.prompts import PromptTemplate +from typing import List + +import pytest + +from langchain.memory import CombinedMemory, ConversationBufferMemory + + +@pytest.fixture() +def example_memory() -> List[ConversationBufferMemory]: + example_1 = ConversationBufferMemory(memory_key="foo") + example_2 = ConversationBufferMemory(memory_key="bar") + example_3 = ConversationBufferMemory(memory_key="bar") + return [example_1, example_2, example_3] + + +def test_basic_functionality(example_memory: List[ConversationBufferMemory]) -> None: + """Test basic functionality of methods exposed by class""" + combined_memory = CombinedMemory(memories=[example_memory[0], example_memory[1]]) + assert combined_memory.memory_variables == ["foo", "bar"] + assert combined_memory.load_memory_variables({}) == {"foo": "", "bar": ""} + combined_memory.save_context( + {"input": "Hello there"}, {"output": "Hello, how can I help you?"} + ) + assert combined_memory.load_memory_variables({}) == { + "foo": "Human: Hello there\nAI: Hello, how can I help you?", + "bar": "Human: Hello there\nAI: Hello, how can I help you?", + } + combined_memory.clear() + assert combined_memory.load_memory_variables({}) == {"foo": "", "bar": ""} + + +def test_repeated_memory_var(example_memory: List[ConversationBufferMemory]) -> None: + """Test raising error when repeated memory variables found""" + with pytest.raises(ValueError): + CombinedMemory(memories=[example_memory[1], example_memory[2]])