"""Test pipeline functionality.""" from typing import Dict, List import pytest from pydantic import BaseModel from langchain.chains.base import Chain from langchain.chains.llm import LLMChain from langchain.chains.sequential import ( SequentialChain, SimpleSequentialChain, construct_sequential_llm_chain, ) from langchain.prompts.prompt import PromptTemplate from tests.unit_tests.llms.fake_llm import FakeLLM class FakeChain(Chain, BaseModel): """Fake Chain for testing purposes.""" input_variables: List[str] output_variables: List[str] @property def input_keys(self) -> List[str]: """Input keys this chain returns.""" return self.input_variables @property def output_keys(self) -> List[str]: """Input keys this chain returns.""" return self.output_variables def _call(self, inputs: Dict[str, str]) -> Dict[str, str]: outputs = {} for var in self.output_variables: variables = [inputs[k] for k in self.input_variables] outputs[var] = f"{' '.join(variables)}foo" return outputs def test_sequential_usage_single_inputs() -> None: """Test sequential on single input chains.""" chain_1 = FakeChain(input_variables=["foo"], output_variables=["bar"]) chain_2 = FakeChain(input_variables=["bar"], output_variables=["baz"]) chain = SequentialChain(chains=[chain_1, chain_2], input_variables=["foo"]) output = chain({"foo": "123"}) expected_output = {"baz": "123foofoo", "foo": "123"} assert output == expected_output def test_sequential_usage_multiple_inputs() -> None: """Test sequential on multiple input chains.""" chain_1 = FakeChain(input_variables=["foo", "test"], output_variables=["bar"]) chain_2 = FakeChain(input_variables=["bar", "foo"], output_variables=["baz"]) chain = SequentialChain(chains=[chain_1, chain_2], input_variables=["foo", "test"]) output = chain({"foo": "123", "test": "456"}) expected_output = { "baz": "123 456foo 123foo", "foo": "123", "test": "456", } assert output == expected_output def test_sequential_usage_multiple_outputs() -> None: """Test sequential usage on multiple output chains.""" chain_1 = FakeChain(input_variables=["foo"], output_variables=["bar", "test"]) chain_2 = FakeChain(input_variables=["bar", "foo"], output_variables=["baz"]) chain = SequentialChain(chains=[chain_1, chain_2], input_variables=["foo"]) output = chain({"foo": "123"}) expected_output = { "baz": "123foo 123foo", "foo": "123", } assert output == expected_output def test_sequential_missing_inputs() -> None: """Test error is raised when input variables are missing.""" chain_1 = FakeChain(input_variables=["foo"], output_variables=["bar"]) chain_2 = FakeChain(input_variables=["bar", "test"], output_variables=["baz"]) with pytest.raises(ValueError): # Also needs "test" as an input SequentialChain(chains=[chain_1, chain_2], input_variables=["foo"]) def test_sequential_bad_outputs() -> None: """Test error is raised when bad outputs are specified.""" chain_1 = FakeChain(input_variables=["foo"], output_variables=["bar"]) chain_2 = FakeChain(input_variables=["bar"], output_variables=["baz"]) with pytest.raises(ValueError): # "test" is not present as an output variable. SequentialChain( chains=[chain_1, chain_2], input_variables=["foo"], output_variables=["test"], ) def test_sequential_valid_outputs() -> None: """Test chain runs when valid outputs are specified.""" chain_1 = FakeChain(input_variables=["foo"], output_variables=["bar"]) chain_2 = FakeChain(input_variables=["bar"], output_variables=["baz"]) chain = SequentialChain( chains=[chain_1, chain_2], input_variables=["foo"], output_variables=["bar", "baz"], ) output = chain({"foo": "123"}, return_only_outputs=True) expected_output = {"baz": "123foofoo", "bar": "123foo"} assert output == expected_output def test_sequential_overlapping_inputs() -> None: """Test error is raised when input variables are overlapping.""" chain_1 = FakeChain(input_variables=["foo"], output_variables=["bar", "test"]) chain_2 = FakeChain(input_variables=["bar"], output_variables=["baz"]) with pytest.raises(ValueError): # "test" is specified as an input, but also is an output of one step SequentialChain(chains=[chain_1, chain_2], input_variables=["foo", "test"]) def test_simple_sequential_functionality() -> None: """Test simple sequential functionality.""" chain_1 = FakeChain(input_variables=["foo"], output_variables=["bar"]) chain_2 = FakeChain(input_variables=["bar"], output_variables=["baz"]) chain = SimpleSequentialChain(chains=[chain_1, chain_2]) output = chain({"input": "123"}) expected_output = {"output": "123foofoo", "input": "123"} assert output == expected_output def test_multi_input_errors() -> None: """Test simple sequential errors if multiple input variables are expected.""" chain_1 = FakeChain(input_variables=["foo"], output_variables=["bar"]) chain_2 = FakeChain(input_variables=["bar", "foo"], output_variables=["baz"]) with pytest.raises(ValueError): SimpleSequentialChain(chains=[chain_1, chain_2]) def test_multi_output_errors() -> None: """Test simple sequential errors if multiple output variables are expected.""" chain_1 = FakeChain(input_variables=["foo"], output_variables=["bar", "grok"]) chain_2 = FakeChain(input_variables=["bar"], output_variables=["baz"]) with pytest.raises(ValueError): SimpleSequentialChain(chains=[chain_1, chain_2]) def test_construct_sequential_llm_chain() -> None: """Test constructing simple sequential chain.""" prompt = PromptTemplate(template="what is {foo}?", input_variables=["foo"]) llm_chain = LLMChain(llm=FakeLLM(), prompt=prompt, output_key="bar") add_ons = [("{bar} and what does it do?", ["bar"], "baz")] chain = construct_sequential_llm_chain(llm_chain, add_ons) expected_new_prompt = PromptTemplate( template="what is {foo}?{bar} and what does it do?", input_variables=["foo", "bar"], ) expected_new_chain = LLMChain( llm=FakeLLM(), prompt=expected_new_prompt, output_key="baz" ) expected_chains = [llm_chain, expected_new_chain] assert chain.chains == expected_chains