Harrison/improve memory (#432)

add AI prefix

add new type of memory

Co-authored-by: Jason <chisanch@usc.edu>
This commit is contained in:
Harrison Chase
2022-12-27 08:23:51 -05:00
committed by GitHub
parent 150b67de10
commit f8b605293f
3 changed files with 134 additions and 4 deletions

View File

@@ -0,0 +1,31 @@
"""Test memory functionality."""
from langchain.chains.conversation.memory import ConversationSummaryBufferMemory
from tests.unit_tests.llms.fake_llm import FakeLLM
def test_summary_buffer_memory_no_buffer_yet() -> None:
"""Test ConversationSummaryBufferMemory when no inputs put in buffer yet."""
memory = ConversationSummaryBufferMemory(llm=FakeLLM(), memory_key="baz")
output = memory.load_memory_variables({})
assert output == {"baz": ""}
def test_summary_buffer_memory_buffer_only() -> None:
"""Test ConversationSummaryBufferMemory when only buffer."""
memory = ConversationSummaryBufferMemory(llm=FakeLLM(), memory_key="baz")
memory.save_context({"input": "bar"}, {"output": "foo"})
assert memory.buffer == ["Human: bar\nAI: foo"]
output = memory.load_memory_variables({})
assert output == {"baz": "Human: bar\nAI: foo"}
def test_summary_buffer_memory_summary() -> None:
"""Test ConversationSummaryBufferMemory when only buffer."""
memory = ConversationSummaryBufferMemory(
llm=FakeLLM(), memory_key="baz", max_token_limit=13
)
memory.save_context({"input": "bar"}, {"output": "foo"})
memory.save_context({"input": "bar1"}, {"output": "foo1"})
assert memory.buffer == ["Human: bar1\nAI: foo1"]
output = memory.load_memory_variables({})
assert output == {"baz": "foo\nHuman: bar1\nAI: foo1"}

View File

@@ -12,6 +12,13 @@ from langchain.prompts.prompt import PromptTemplate
from tests.unit_tests.llms.fake_llm import FakeLLM
def test_memory_ai_prefix() -> None:
"""Test that ai_prefix in the memory component works."""
memory = ConversationBufferMemory(memory_key="foo", ai_prefix="Assistant")
memory.save_context({"input": "bar"}, {"output": "foo"})
assert memory.buffer == "\nHuman: bar\nAssistant: foo"
def test_conversation_chain_works() -> None:
"""Test that conversation chain works in basic setting."""
llm = FakeLLM()
@@ -42,6 +49,7 @@ def test_conversation_chain_errors_bad_variable() -> None:
"memory",
[
ConversationBufferMemory(memory_key="baz"),
ConversationalBufferWindowMemory(memory_key="baz"),
ConversationSummaryMemory(llm=FakeLLM(), memory_key="baz"),
],
)
@@ -81,7 +89,7 @@ def test_clearing_conversation_memory(memory: Memory) -> None:
"""Test clearing the conversation memory."""
# This is a good input because the input is not the same as baz.
good_inputs = {"foo": "bar", "baz": "foo"}
# This is a good output because these is one variable.
# This is a good output because there is one variable.
good_outputs = {"bar": "foo"}
memory.save_context(good_inputs, good_outputs)