mirror of
https://github.com/hwchase17/langchain.git
synced 2025-06-20 13:54:48 +00:00
Add clear()
method for Memory
(#305)
a simple helper to clear the buffer in `Conversation*Memory` classes
This commit is contained in:
parent
e02d6b2288
commit
19a9fa16a9
@ -29,6 +29,10 @@ class Memory(BaseModel, ABC):
|
|||||||
def save_context(self, inputs: Dict[str, Any], outputs: Dict[str, str]) -> None:
|
def save_context(self, inputs: Dict[str, Any], outputs: Dict[str, str]) -> None:
|
||||||
"""Save the context of this model run to memory."""
|
"""Save the context of this model run to memory."""
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def clear(self) -> None:
|
||||||
|
"""Clear memory contents."""
|
||||||
|
|
||||||
|
|
||||||
def _get_verbosity() -> bool:
|
def _get_verbosity() -> bool:
|
||||||
return langchain.verbose
|
return langchain.verbose
|
||||||
|
@ -46,6 +46,10 @@ class ConversationBufferMemory(Memory, BaseModel):
|
|||||||
ai = "AI: " + outputs[list(outputs.keys())[0]]
|
ai = "AI: " + outputs[list(outputs.keys())[0]]
|
||||||
self.buffer += "\n" + "\n".join([human, ai])
|
self.buffer += "\n" + "\n".join([human, ai])
|
||||||
|
|
||||||
|
def clear(self) -> None:
|
||||||
|
"""Clear memory contents."""
|
||||||
|
self.buffer = ""
|
||||||
|
|
||||||
|
|
||||||
class ConversationalBufferWindowMemory(Memory, BaseModel):
|
class ConversationalBufferWindowMemory(Memory, BaseModel):
|
||||||
"""Buffer for storing conversation memory."""
|
"""Buffer for storing conversation memory."""
|
||||||
@ -75,6 +79,10 @@ class ConversationalBufferWindowMemory(Memory, BaseModel):
|
|||||||
ai = "AI: " + outputs[list(outputs.keys())[0]]
|
ai = "AI: " + outputs[list(outputs.keys())[0]]
|
||||||
self.buffer.append("\n".join([human, ai]))
|
self.buffer.append("\n".join([human, ai]))
|
||||||
|
|
||||||
|
def clear(self) -> None:
|
||||||
|
"""Clear memory contents."""
|
||||||
|
self.buffer = []
|
||||||
|
|
||||||
|
|
||||||
class ConversationSummaryMemory(Memory, BaseModel):
|
class ConversationSummaryMemory(Memory, BaseModel):
|
||||||
"""Conversation summarizer to memory."""
|
"""Conversation summarizer to memory."""
|
||||||
@ -118,3 +126,7 @@ class ConversationSummaryMemory(Memory, BaseModel):
|
|||||||
new_lines = "\n".join([human, ai])
|
new_lines = "\n".join([human, ai])
|
||||||
chain = LLMChain(llm=self.llm, prompt=self.prompt)
|
chain = LLMChain(llm=self.llm, prompt=self.prompt)
|
||||||
self.buffer = chain.predict(summary=self.buffer, new_lines=new_lines)
|
self.buffer = chain.predict(summary=self.buffer, new_lines=new_lines)
|
||||||
|
|
||||||
|
def clear(self) -> None:
|
||||||
|
"""Clear memory contents."""
|
||||||
|
self.buffer = ""
|
||||||
|
@ -4,6 +4,7 @@ import pytest
|
|||||||
from langchain.chains.base import Memory
|
from langchain.chains.base import Memory
|
||||||
from langchain.chains.conversation.base import ConversationChain
|
from langchain.chains.conversation.base import ConversationChain
|
||||||
from langchain.chains.conversation.memory import (
|
from langchain.chains.conversation.memory import (
|
||||||
|
ConversationalBufferWindowMemory,
|
||||||
ConversationBufferMemory,
|
ConversationBufferMemory,
|
||||||
ConversationSummaryMemory,
|
ConversationSummaryMemory,
|
||||||
)
|
)
|
||||||
@ -66,3 +67,23 @@ def test_conversation_memory(memory: Memory) -> None:
|
|||||||
bad_outputs = {"foo": "bar", "foo1": "bar"}
|
bad_outputs = {"foo": "bar", "foo1": "bar"}
|
||||||
with pytest.raises(ValueError):
|
with pytest.raises(ValueError):
|
||||||
memory.save_context(good_inputs, bad_outputs)
|
memory.save_context(good_inputs, bad_outputs)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
"memory",
|
||||||
|
[
|
||||||
|
ConversationBufferMemory(memory_key="baz"),
|
||||||
|
ConversationSummaryMemory(llm=FakeLLM(), memory_key="baz"),
|
||||||
|
ConversationalBufferWindowMemory(memory_key="baz"),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
def test_clearing_conversation_memory(memory: Memory) -> None:
|
||||||
|
"""Test clearing the conversation memory."""
|
||||||
|
# This is a good input because the input is not the same as baz.
|
||||||
|
good_inputs = {"foo": "bar", "baz": "foo"}
|
||||||
|
# This is a good output because these is one variable.
|
||||||
|
good_outputs = {"bar": "foo"}
|
||||||
|
memory.save_context(good_inputs, good_outputs)
|
||||||
|
|
||||||
|
memory.clear()
|
||||||
|
assert memory.load_memory_variables({}) == {"baz": ""}
|
||||||
|
Loading…
Reference in New Issue
Block a user