mirror of
https://github.com/hwchase17/langchain.git
synced 2025-07-07 05:30:39 +00:00
Harrison/combined memory (#3935)
Co-authored-by: engkheng <60956360+outday29@users.noreply.github.com>
This commit is contained in:
parent
c4cb55a0c5
commit
cd3f8582cb
@ -1,4 +1,6 @@
|
|||||||
from typing import Any, Dict, List
|
from typing import Any, Dict, List, Set
|
||||||
|
|
||||||
|
from pydantic import validator
|
||||||
|
|
||||||
from langchain.schema import BaseMemory
|
from langchain.schema import BaseMemory
|
||||||
|
|
||||||
@ -9,6 +11,22 @@ class CombinedMemory(BaseMemory):
|
|||||||
memories: List[BaseMemory]
|
memories: List[BaseMemory]
|
||||||
"""For tracking all the memories that should be accessed."""
|
"""For tracking all the memories that should be accessed."""
|
||||||
|
|
||||||
|
@validator("memories")
|
||||||
|
def check_repeated_memory_variable(
|
||||||
|
cls, value: List[BaseMemory]
|
||||||
|
) -> List[BaseMemory]:
|
||||||
|
all_variables: Set[str] = set()
|
||||||
|
for val in value:
|
||||||
|
overlap = all_variables.intersection(val.memory_variables)
|
||||||
|
if overlap:
|
||||||
|
raise ValueError(
|
||||||
|
f"The same variables {overlap} are found in multiple"
|
||||||
|
"memory object, which is not allowed by CombinedMemory."
|
||||||
|
)
|
||||||
|
all_variables |= set(val.memory_variables)
|
||||||
|
|
||||||
|
return value
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def memory_variables(self) -> List[str]:
|
def memory_variables(self) -> List[str]:
|
||||||
"""All the memory variables that this instance provides."""
|
"""All the memory variables that this instance provides."""
|
||||||
|
37
tests/unit_tests/memory/test_combined_memory.py
Normal file
37
tests/unit_tests/memory/test_combined_memory.py
Normal file
@ -0,0 +1,37 @@
|
|||||||
|
"""Test for CombinedMemory class"""
|
||||||
|
# from langchain.prompts import PromptTemplate
|
||||||
|
from typing import List
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from langchain.memory import CombinedMemory, ConversationBufferMemory
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture()
|
||||||
|
def example_memory() -> List[ConversationBufferMemory]:
|
||||||
|
example_1 = ConversationBufferMemory(memory_key="foo")
|
||||||
|
example_2 = ConversationBufferMemory(memory_key="bar")
|
||||||
|
example_3 = ConversationBufferMemory(memory_key="bar")
|
||||||
|
return [example_1, example_2, example_3]
|
||||||
|
|
||||||
|
|
||||||
|
def test_basic_functionality(example_memory: List[ConversationBufferMemory]) -> None:
|
||||||
|
"""Test basic functionality of methods exposed by class"""
|
||||||
|
combined_memory = CombinedMemory(memories=[example_memory[0], example_memory[1]])
|
||||||
|
assert combined_memory.memory_variables == ["foo", "bar"]
|
||||||
|
assert combined_memory.load_memory_variables({}) == {"foo": "", "bar": ""}
|
||||||
|
combined_memory.save_context(
|
||||||
|
{"input": "Hello there"}, {"output": "Hello, how can I help you?"}
|
||||||
|
)
|
||||||
|
assert combined_memory.load_memory_variables({}) == {
|
||||||
|
"foo": "Human: Hello there\nAI: Hello, how can I help you?",
|
||||||
|
"bar": "Human: Hello there\nAI: Hello, how can I help you?",
|
||||||
|
}
|
||||||
|
combined_memory.clear()
|
||||||
|
assert combined_memory.load_memory_variables({}) == {"foo": "", "bar": ""}
|
||||||
|
|
||||||
|
|
||||||
|
def test_repeated_memory_var(example_memory: List[ConversationBufferMemory]) -> None:
|
||||||
|
"""Test raising error when repeated memory variables found"""
|
||||||
|
with pytest.raises(ValueError):
|
||||||
|
CombinedMemory(memories=[example_memory[1], example_memory[2]])
|
Loading…
Reference in New Issue
Block a user