Harrison/memory refactor (#1478)

moves memory to own module, factors out common stuff
This commit is contained in:
Harrison Chase
2023-03-07 07:59:37 -08:00
committed by GitHub
parent df6865cd52
commit 7bec461782
44 changed files with 3084 additions and 1610 deletions

View File

@@ -1,5 +1,5 @@
"""Test memory functionality."""
from langchain.chains.conversation.memory import ConversationSummaryBufferMemory
from langchain.memory.summary_buffer import ConversationSummaryBufferMemory
from tests.unit_tests.llms.fake_llm import FakeLLM

View File

@@ -5,11 +5,12 @@ import pytest
from pydantic import BaseModel
from langchain.callbacks.base import CallbackManager
from langchain.chains.base import Chain, Memory
from langchain.chains.base import Chain
from langchain.schema import BaseMemory
from tests.unit_tests.callbacks.fake_callback_handler import FakeCallbackHandler
class FakeMemory(Memory, BaseModel):
class FakeMemory(BaseMemory, BaseModel):
"""Fake memory class for testing purposes."""
@property

View File

@@ -1,14 +1,12 @@
"""Test conversation chain and memory."""
import pytest
from langchain.chains.base import Memory
from langchain.chains.conversation.base import ConversationChain
from langchain.chains.conversation.memory import (
ConversationBufferMemory,
ConversationBufferWindowMemory,
ConversationSummaryMemory,
)
from langchain.memory.buffer import ConversationBufferMemory
from langchain.memory.buffer_window import ConversationBufferWindowMemory
from langchain.memory.summary import ConversationSummaryMemory
from langchain.prompts.prompt import PromptTemplate
from langchain.schema import BaseMemory
from tests.unit_tests.llms.fake_llm import FakeLLM
@@ -16,14 +14,14 @@ def test_memory_ai_prefix() -> None:
"""Test that ai_prefix in the memory component works."""
memory = ConversationBufferMemory(memory_key="foo", ai_prefix="Assistant")
memory.save_context({"input": "bar"}, {"output": "foo"})
assert memory.buffer == "\nHuman: bar\nAssistant: foo"
assert memory.buffer == "Human: bar\nAssistant: foo"
def test_memory_human_prefix() -> None:
"""Test that human_prefix in the memory component works."""
memory = ConversationBufferMemory(memory_key="foo", human_prefix="Friend")
memory.save_context({"input": "bar"}, {"output": "foo"})
assert memory.buffer == "\nFriend: bar\nAI: foo"
assert memory.buffer == "Friend: bar\nAI: foo"
def test_conversation_chain_works() -> None:
@@ -60,7 +58,7 @@ def test_conversation_chain_errors_bad_variable() -> None:
ConversationSummaryMemory(llm=FakeLLM(), memory_key="baz"),
],
)
def test_conversation_memory(memory: Memory) -> None:
def test_conversation_memory(memory: BaseMemory) -> None:
"""Test basic conversation memory functionality."""
# This is a good input because the input is not the same as baz.
good_inputs = {"foo": "bar", "baz": "foo"}
@@ -92,7 +90,7 @@ def test_conversation_memory(memory: Memory) -> None:
ConversationBufferWindowMemory(memory_key="baz"),
],
)
def test_clearing_conversation_memory(memory: Memory) -> None:
def test_clearing_conversation_memory(memory: BaseMemory) -> None:
"""Test clearing the conversation memory."""
# This is a good input because the input is not the same as baz.
good_inputs = {"foo": "bar", "baz": "foo"}

View File

@@ -1,4 +1,4 @@
from langchain.chains.base import SimpleMemory
from langchain.memory.simple import SimpleMemory
def test_simple_memory() -> None:

View File

@@ -4,8 +4,9 @@ from typing import Dict, List
import pytest
from pydantic import BaseModel
from langchain.chains.base import Chain, SimpleMemory
from langchain.chains.base import Chain
from langchain.chains.sequential import SequentialChain, SimpleSequentialChain
from langchain.memory.simple import SimpleMemory
class FakeChain(Chain, BaseModel):