Add clear() method for Memory (#305)

a simple helper to clear the buffer in `Conversation*Memory` classes
This commit is contained in:
Shobith Alva
2022-12-11 07:09:06 -08:00
committed by GitHub
parent e02d6b2288
commit 19a9fa16a9
3 changed files with 37 additions and 0 deletions

View File

@@ -4,6 +4,7 @@ import pytest
from langchain.chains.base import Memory
from langchain.chains.conversation.base import ConversationChain
from langchain.chains.conversation.memory import (
ConversationalBufferWindowMemory,
ConversationBufferMemory,
ConversationSummaryMemory,
)
@@ -66,3 +67,23 @@ def test_conversation_memory(memory: Memory) -> None:
bad_outputs = {"foo": "bar", "foo1": "bar"}
with pytest.raises(ValueError):
memory.save_context(good_inputs, bad_outputs)
@pytest.mark.parametrize(
"memory",
[
ConversationBufferMemory(memory_key="baz"),
ConversationSummaryMemory(llm=FakeLLM(), memory_key="baz"),
ConversationalBufferWindowMemory(memory_key="baz"),
],
)
def test_clearing_conversation_memory(memory: Memory) -> None:
"""Test clearing the conversation memory."""
# This is a good input because the input is not the same as baz.
good_inputs = {"foo": "bar", "baz": "foo"}
# This is a good output because these is one variable.
good_outputs = {"bar": "foo"}
memory.save_context(good_inputs, good_outputs)
memory.clear()
assert memory.load_memory_variables({}) == {"baz": ""}