mirror of
https://github.com/hwchase17/langchain.git
synced 2025-09-25 04:49:17 +00:00
Harrison/memory refactor (#1478)
moves memory to own module, factors out common stuff
This commit is contained in:
@@ -1,5 +1,5 @@
|
||||
"""Test memory functionality."""
|
||||
from langchain.chains.conversation.memory import ConversationSummaryBufferMemory
|
||||
from langchain.memory.summary_buffer import ConversationSummaryBufferMemory
|
||||
from tests.unit_tests.llms.fake_llm import FakeLLM
|
||||
|
||||
|
||||
|
@@ -5,11 +5,12 @@ import pytest
|
||||
from pydantic import BaseModel
|
||||
|
||||
from langchain.callbacks.base import CallbackManager
|
||||
from langchain.chains.base import Chain, Memory
|
||||
from langchain.chains.base import Chain
|
||||
from langchain.schema import BaseMemory
|
||||
from tests.unit_tests.callbacks.fake_callback_handler import FakeCallbackHandler
|
||||
|
||||
|
||||
class FakeMemory(Memory, BaseModel):
|
||||
class FakeMemory(BaseMemory, BaseModel):
|
||||
"""Fake memory class for testing purposes."""
|
||||
|
||||
@property
|
||||
|
@@ -1,14 +1,12 @@
|
||||
"""Test conversation chain and memory."""
|
||||
import pytest
|
||||
|
||||
from langchain.chains.base import Memory
|
||||
from langchain.chains.conversation.base import ConversationChain
|
||||
from langchain.chains.conversation.memory import (
|
||||
ConversationBufferMemory,
|
||||
ConversationBufferWindowMemory,
|
||||
ConversationSummaryMemory,
|
||||
)
|
||||
from langchain.memory.buffer import ConversationBufferMemory
|
||||
from langchain.memory.buffer_window import ConversationBufferWindowMemory
|
||||
from langchain.memory.summary import ConversationSummaryMemory
|
||||
from langchain.prompts.prompt import PromptTemplate
|
||||
from langchain.schema import BaseMemory
|
||||
from tests.unit_tests.llms.fake_llm import FakeLLM
|
||||
|
||||
|
||||
@@ -16,14 +14,14 @@ def test_memory_ai_prefix() -> None:
|
||||
"""Test that ai_prefix in the memory component works."""
|
||||
memory = ConversationBufferMemory(memory_key="foo", ai_prefix="Assistant")
|
||||
memory.save_context({"input": "bar"}, {"output": "foo"})
|
||||
assert memory.buffer == "\nHuman: bar\nAssistant: foo"
|
||||
assert memory.buffer == "Human: bar\nAssistant: foo"
|
||||
|
||||
|
||||
def test_memory_human_prefix() -> None:
|
||||
"""Test that human_prefix in the memory component works."""
|
||||
memory = ConversationBufferMemory(memory_key="foo", human_prefix="Friend")
|
||||
memory.save_context({"input": "bar"}, {"output": "foo"})
|
||||
assert memory.buffer == "\nFriend: bar\nAI: foo"
|
||||
assert memory.buffer == "Friend: bar\nAI: foo"
|
||||
|
||||
|
||||
def test_conversation_chain_works() -> None:
|
||||
@@ -60,7 +58,7 @@ def test_conversation_chain_errors_bad_variable() -> None:
|
||||
ConversationSummaryMemory(llm=FakeLLM(), memory_key="baz"),
|
||||
],
|
||||
)
|
||||
def test_conversation_memory(memory: Memory) -> None:
|
||||
def test_conversation_memory(memory: BaseMemory) -> None:
|
||||
"""Test basic conversation memory functionality."""
|
||||
# This is a good input because the input is not the same as baz.
|
||||
good_inputs = {"foo": "bar", "baz": "foo"}
|
||||
@@ -92,7 +90,7 @@ def test_conversation_memory(memory: Memory) -> None:
|
||||
ConversationBufferWindowMemory(memory_key="baz"),
|
||||
],
|
||||
)
|
||||
def test_clearing_conversation_memory(memory: Memory) -> None:
|
||||
def test_clearing_conversation_memory(memory: BaseMemory) -> None:
|
||||
"""Test clearing the conversation memory."""
|
||||
# This is a good input because the input is not the same as baz.
|
||||
good_inputs = {"foo": "bar", "baz": "foo"}
|
||||
|
@@ -1,4 +1,4 @@
|
||||
from langchain.chains.base import SimpleMemory
|
||||
from langchain.memory.simple import SimpleMemory
|
||||
|
||||
|
||||
def test_simple_memory() -> None:
|
||||
|
@@ -4,8 +4,9 @@ from typing import Dict, List
|
||||
import pytest
|
||||
from pydantic import BaseModel
|
||||
|
||||
from langchain.chains.base import Chain, SimpleMemory
|
||||
from langchain.chains.base import Chain
|
||||
from langchain.chains.sequential import SequentialChain, SimpleSequentialChain
|
||||
from langchain.memory.simple import SimpleMemory
|
||||
|
||||
|
||||
class FakeChain(Chain, BaseModel):
|
||||
|
Reference in New Issue
Block a user