mirror of
https://github.com/hwchase17/langchain.git
synced 2025-08-31 18:38:48 +00:00
Harrison/sequential chains (#168)
add support for basic sequential chains
This commit is contained in:
140
tests/unit_tests/chains/test_sequential.py
Normal file
140
tests/unit_tests/chains/test_sequential.py
Normal file
@@ -0,0 +1,140 @@
|
||||
"""Test pipeline functionality."""
|
||||
from typing import Dict, List
|
||||
|
||||
import pytest
|
||||
from pydantic import BaseModel
|
||||
|
||||
from langchain.chains.base import Chain
|
||||
from langchain.chains.sequential import SequentialChain, SimpleSequentialChain
|
||||
|
||||
|
||||
class FakeChain(Chain, BaseModel):
|
||||
"""Fake Chain for testing purposes."""
|
||||
|
||||
input_variables: List[str]
|
||||
output_variables: List[str]
|
||||
|
||||
@property
|
||||
def input_keys(self) -> List[str]:
|
||||
"""Input keys this chain returns."""
|
||||
return self.input_variables
|
||||
|
||||
@property
|
||||
def output_keys(self) -> List[str]:
|
||||
"""Input keys this chain returns."""
|
||||
return self.output_variables
|
||||
|
||||
def _call(self, inputs: Dict[str, str]) -> Dict[str, str]:
|
||||
outputs = {}
|
||||
for var in self.output_variables:
|
||||
variables = [inputs[k] for k in self.input_variables]
|
||||
outputs[var] = " ".join(variables) + "foo"
|
||||
return outputs
|
||||
|
||||
|
||||
def test_sequential_usage_single_inputs() -> None:
|
||||
"""Test sequential on single input chains."""
|
||||
chain_1 = FakeChain(input_variables=["foo"], output_variables=["bar"])
|
||||
chain_2 = FakeChain(input_variables=["bar"], output_variables=["baz"])
|
||||
chain = SequentialChain(chains=[chain_1, chain_2], input_variables=["foo"])
|
||||
output = chain({"foo": "123"})
|
||||
expected_output = {"baz": "123foofoo", "foo": "123"}
|
||||
assert output == expected_output
|
||||
|
||||
|
||||
def test_sequential_usage_multiple_inputs() -> None:
|
||||
"""Test sequential on multiple input chains."""
|
||||
chain_1 = FakeChain(input_variables=["foo", "test"], output_variables=["bar"])
|
||||
chain_2 = FakeChain(input_variables=["bar", "foo"], output_variables=["baz"])
|
||||
chain = SequentialChain(chains=[chain_1, chain_2], input_variables=["foo", "test"])
|
||||
output = chain({"foo": "123", "test": "456"})
|
||||
expected_output = {
|
||||
"baz": "123 456foo 123foo",
|
||||
"foo": "123",
|
||||
"test": "456",
|
||||
}
|
||||
assert output == expected_output
|
||||
|
||||
|
||||
def test_sequential_usage_multiple_outputs() -> None:
|
||||
"""Test sequential usage on multiple output chains."""
|
||||
chain_1 = FakeChain(input_variables=["foo"], output_variables=["bar", "test"])
|
||||
chain_2 = FakeChain(input_variables=["bar", "foo"], output_variables=["baz"])
|
||||
chain = SequentialChain(chains=[chain_1, chain_2], input_variables=["foo"])
|
||||
output = chain({"foo": "123"})
|
||||
expected_output = {
|
||||
"baz": "123foo 123foo",
|
||||
"foo": "123",
|
||||
}
|
||||
assert output == expected_output
|
||||
|
||||
|
||||
def test_sequential_missing_inputs() -> None:
|
||||
"""Test error is raised when input variables are missing."""
|
||||
chain_1 = FakeChain(input_variables=["foo"], output_variables=["bar"])
|
||||
chain_2 = FakeChain(input_variables=["bar", "test"], output_variables=["baz"])
|
||||
with pytest.raises(ValueError):
|
||||
# Also needs "test" as an input
|
||||
SequentialChain(chains=[chain_1, chain_2], input_variables=["foo"])
|
||||
|
||||
|
||||
def test_sequential_bad_outputs() -> None:
|
||||
"""Test error is raised when bad outputs are specified."""
|
||||
chain_1 = FakeChain(input_variables=["foo"], output_variables=["bar"])
|
||||
chain_2 = FakeChain(input_variables=["bar"], output_variables=["baz"])
|
||||
with pytest.raises(ValueError):
|
||||
# "test" is not present as an output variable.
|
||||
SequentialChain(
|
||||
chains=[chain_1, chain_2],
|
||||
input_variables=["foo"],
|
||||
output_variables=["test"],
|
||||
)
|
||||
|
||||
|
||||
def test_sequential_valid_outputs() -> None:
|
||||
"""Test chain runs when valid outputs are specified."""
|
||||
chain_1 = FakeChain(input_variables=["foo"], output_variables=["bar"])
|
||||
chain_2 = FakeChain(input_variables=["bar"], output_variables=["baz"])
|
||||
chain = SequentialChain(
|
||||
chains=[chain_1, chain_2],
|
||||
input_variables=["foo"],
|
||||
output_variables=["bar", "baz"],
|
||||
)
|
||||
output = chain({"foo": "123"}, return_only_outputs=True)
|
||||
expected_output = {"baz": "123foofoo", "bar": "123foo"}
|
||||
assert output == expected_output
|
||||
|
||||
|
||||
def test_sequential_overlapping_inputs() -> None:
|
||||
"""Test error is raised when input variables are overlapping."""
|
||||
chain_1 = FakeChain(input_variables=["foo"], output_variables=["bar", "test"])
|
||||
chain_2 = FakeChain(input_variables=["bar"], output_variables=["baz"])
|
||||
with pytest.raises(ValueError):
|
||||
# "test" is specified as an input, but also is an output of one step
|
||||
SequentialChain(chains=[chain_1, chain_2], input_variables=["foo", "test"])
|
||||
|
||||
|
||||
def test_simple_sequential_functionality() -> None:
|
||||
"""Test simple sequential functionality."""
|
||||
chain_1 = FakeChain(input_variables=["foo"], output_variables=["bar"])
|
||||
chain_2 = FakeChain(input_variables=["bar"], output_variables=["baz"])
|
||||
chain = SimpleSequentialChain(chains=[chain_1, chain_2])
|
||||
output = chain({"input": "123"})
|
||||
expected_output = {"output": "123foofoo", "input": "123"}
|
||||
assert output == expected_output
|
||||
|
||||
|
||||
def test_multi_input_errors() -> None:
|
||||
"""Test simple sequential errors if multiple input variables are expected."""
|
||||
chain_1 = FakeChain(input_variables=["foo"], output_variables=["bar"])
|
||||
chain_2 = FakeChain(input_variables=["bar", "foo"], output_variables=["baz"])
|
||||
with pytest.raises(ValueError):
|
||||
SimpleSequentialChain(chains=[chain_1, chain_2])
|
||||
|
||||
|
||||
def test_multi_output_errors() -> None:
|
||||
"""Test simple sequential errors if multiple output variables are expected."""
|
||||
chain_1 = FakeChain(input_variables=["foo"], output_variables=["bar", "grok"])
|
||||
chain_2 = FakeChain(input_variables=["bar"], output_variables=["baz"])
|
||||
with pytest.raises(ValueError):
|
||||
SimpleSequentialChain(chains=[chain_1, chain_2])
|
Reference in New Issue
Block a user