mirror of
https://github.com/hwchase17/langchain.git
synced 2025-08-09 13:00:34 +00:00
fix(memory): allow internal chains to use memory (#6769)
Fixed #6768. This is a workaround only. I think a better longer-term solution is for chains to declare how many input variables they *actually* need (as opposed to ones that are in the prompt, where some may be satisfied by the memory). Then, a wrapping chain can check the input match against the actual input variables. @hwchase17
This commit is contained in:
parent
488d2d5da9
commit
f307ca094b
@ -62,6 +62,9 @@ class SequentialChain(Chain):
|
|||||||
|
|
||||||
for chain in chains:
|
for chain in chains:
|
||||||
missing_vars = set(chain.input_keys).difference(known_variables)
|
missing_vars = set(chain.input_keys).difference(known_variables)
|
||||||
|
if chain.memory:
|
||||||
|
missing_vars = missing_vars.difference(chain.memory.memory_variables)
|
||||||
|
|
||||||
if missing_vars:
|
if missing_vars:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
f"Missing required input keys: {missing_vars}, "
|
f"Missing required input keys: {missing_vars}, "
|
||||||
|
@ -6,6 +6,7 @@ import pytest
|
|||||||
from langchain.callbacks.manager import CallbackManagerForChainRun
|
from langchain.callbacks.manager import CallbackManagerForChainRun
|
||||||
from langchain.chains.base import Chain
|
from langchain.chains.base import Chain
|
||||||
from langchain.chains.sequential import SequentialChain, SimpleSequentialChain
|
from langchain.chains.sequential import SequentialChain, SimpleSequentialChain
|
||||||
|
from langchain.memory import ConversationBufferMemory
|
||||||
from langchain.memory.simple import SimpleMemory
|
from langchain.memory.simple import SimpleMemory
|
||||||
|
|
||||||
|
|
||||||
@ -81,6 +82,21 @@ def test_sequential_usage_memory() -> None:
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def test_sequential_internal_chain_use_memory() -> None:
|
||||||
|
"""Test sequential usage with memory for one of the internal chains."""
|
||||||
|
memory = ConversationBufferMemory(memory_key="bla")
|
||||||
|
memory.save_context({"input": "yo"}, {"output": "ya"})
|
||||||
|
chain_1 = FakeChain(
|
||||||
|
input_variables=["foo", "bla"], output_variables=["bar"], memory=memory
|
||||||
|
)
|
||||||
|
chain_2 = FakeChain(input_variables=["bar"], output_variables=["baz"])
|
||||||
|
chain = SequentialChain(chains=[chain_1, chain_2], input_variables=["foo"])
|
||||||
|
output = chain({"foo": "123"})
|
||||||
|
print("HEYYY OUTPUT", output)
|
||||||
|
expected_output = {"foo": "123", "baz": "123 Human: yo\nAI: yafoofoo"}
|
||||||
|
assert output == expected_output
|
||||||
|
|
||||||
|
|
||||||
def test_sequential_usage_multiple_outputs() -> None:
|
def test_sequential_usage_multiple_outputs() -> None:
|
||||||
"""Test sequential usage on multiple output chains."""
|
"""Test sequential usage on multiple output chains."""
|
||||||
chain_1 = FakeChain(input_variables=["foo"], output_variables=["bar", "test"])
|
chain_1 = FakeChain(input_variables=["foo"], output_variables=["bar", "test"])
|
||||||
|
Loading…
Reference in New Issue
Block a user