From b5325c212b32c17c828983c8f20b06d70a1abb23 Mon Sep 17 00:00:00 2001 From: Harrison Chase Date: Sat, 19 Nov 2022 09:55:44 -0800 Subject: [PATCH] chain pipelines --- langchain/chains/pipeline.py | 71 ++++++++++++ langchain/chains/simple_pipeline.py | 59 ++++++++++ tests/unit_tests/chains/test_pipeline.py | 103 ++++++++++++++++++ .../unit_tests/chains/test_simple_pipeline.py | 59 ++++++++++ 4 files changed, 292 insertions(+) create mode 100644 langchain/chains/pipeline.py create mode 100644 langchain/chains/simple_pipeline.py create mode 100644 tests/unit_tests/chains/test_pipeline.py create mode 100644 tests/unit_tests/chains/test_simple_pipeline.py diff --git a/langchain/chains/pipeline.py b/langchain/chains/pipeline.py new file mode 100644 index 00000000000..2d0fcc7790a --- /dev/null +++ b/langchain/chains/pipeline.py @@ -0,0 +1,71 @@ +"""Chain pipeline where the outputs of one step feed directly into next.""" + +from typing import Dict, List + +from pydantic import BaseModel, Extra, root_validator + +from langchain.chains.base import Chain + + +class Pipeline(Chain, BaseModel): + """Chain pipeline where the outputs of one step feed directly into next.""" + + chains: List[Chain] + input_variables: List[str] + output_variables: List[str] #: :meta private: + + class Config: + """Configuration for this pydantic object.""" + + extra = Extra.forbid + arbitrary_types_allowed = True + + @property + def input_keys(self) -> List[str]: + """Expect input key. + + :meta private: + """ + return self.input_variables + + @property + def output_keys(self) -> List[str]: + """Return output key. + + :meta private: + """ + return self.output_variables + + @root_validator(pre=True) + def validate_chains(cls, values: Dict) -> Dict: + """Validate that the correct inputs exist for all chains.""" + chains = values["chains"] + input_variables = values["input_variables"] + known_variables = set(input_variables) + for chain in chains: + missing_vars = set(chain.input_keys).difference(known_variables) + if missing_vars: + raise ValueError(f"Missing required input keys: {missing_vars}") + overlapping_keys = known_variables.intersection(chain.output_keys) + if overlapping_keys: + raise ValueError( + f"Chain returned keys that already exist: {overlapping_keys}" + ) + known_variables |= set(chain.output_keys) + + if "output_variables" not in values: + values["output_variables"] = known_variables.difference(input_variables) + else: + missing_vars = known_variables.difference(values["output_variables"]) + if missing_vars: + raise ValueError( + f"Expected output variables that were not found: {missing_vars}." + ) + return values + + def _call(self, inputs: Dict[str, str]) -> Dict[str, str]: + known_values = inputs.copy() + for chain in self.chains: + outputs = chain(known_values) + known_values.update(outputs) + return {k: known_values[k] for k in self.output_variables} diff --git a/langchain/chains/simple_pipeline.py b/langchain/chains/simple_pipeline.py new file mode 100644 index 00000000000..a38dce95c41 --- /dev/null +++ b/langchain/chains/simple_pipeline.py @@ -0,0 +1,59 @@ +"""Simple chain pipeline where the outputs of one step feed directly into next.""" + +from typing import Dict, List + +from pydantic import BaseModel, Extra, root_validator + +from langchain.chains.base import Chain + + +class SimplePipeline(Chain, BaseModel): + """Simple chain pipeline where the outputs of one step feed directly into next.""" + + chains: List[Chain] + input_key: str = "input" #: :meta private: + output_key: str = "output" #: :meta private: + + class Config: + """Configuration for this pydantic object.""" + + extra = Extra.forbid + arbitrary_types_allowed = True + + @property + def input_keys(self) -> List[str]: + """Expect input key. + + :meta private: + """ + return [self.input_key] + + @property + def output_keys(self) -> List[str]: + """Return output key. + + :meta private: + """ + return [self.output_key] + + @root_validator() + def validate_chains(cls, values: Dict) -> Dict: + """Validate that chains are all single input/output.""" + for chain in values["chains"]: + if len(chain.input_keys) != 1: + raise ValueError( + "Chains used in SimplePipeline should all have one input, got " + f"{chain} with {len(chain.input_keys)} inputs." + ) + if len(chain.output_keys) != 1: + raise ValueError( + "Chains used in SimplePipeline should all have one output, got " + f"{chain} with {len(chain.output_keys)} outputs." + ) + return values + + def _call(self, inputs: Dict[str, str]) -> Dict[str, str]: + _input = inputs[self.input_key] + for chain in self.chains: + _input = chain.run(_input) + return {self.output_key: _input} diff --git a/tests/unit_tests/chains/test_pipeline.py b/tests/unit_tests/chains/test_pipeline.py new file mode 100644 index 00000000000..302f527d6c9 --- /dev/null +++ b/tests/unit_tests/chains/test_pipeline.py @@ -0,0 +1,103 @@ +"""Test pipeline functionality.""" +from typing import Dict, List + +import pytest +from pydantic import BaseModel + +from langchain.chains.base import Chain +from langchain.chains.pipeline import Pipeline + + +class FakeChain(Chain, BaseModel): + """Fake Chain for testing purposes.""" + + input_variables: List[str] + output_variables: List[str] + + @property + def input_keys(self) -> List[str]: + """Input keys this chain returns.""" + return self.input_variables + + @property + def output_keys(self) -> List[str]: + """Input keys this chain returns.""" + return self.output_variables + + def _call(self, inputs: Dict[str, str]) -> Dict[str, str]: + outputs = {} + for var in self.output_variables: + variables = [inputs[k] for k in self.input_variables] + outputs[var] = " ".join(variables) + "foo" + return outputs + + +def test_pipeline_usage_single_inputs() -> None: + """Test pipeline on single input chains.""" + chain_1 = FakeChain(input_variables=["foo"], output_variables=["bar"]) + chain_2 = FakeChain(input_variables=["bar"], output_variables=["baz"]) + pipeline = Pipeline(chains=[chain_1, chain_2], input_variables=["foo"]) + output = pipeline({"foo": "123"}) + expected_output = {"bar": "123foo", "baz": "123foofoo", "foo": "123"} + assert output == expected_output + + +def test_pipeline_usage_multiple_inputs() -> None: + """Test pipeline on multiple input chains.""" + chain_1 = FakeChain(input_variables=["foo", "test"], output_variables=["bar"]) + chain_2 = FakeChain(input_variables=["bar", "foo"], output_variables=["baz"]) + pipeline = Pipeline(chains=[chain_1, chain_2], input_variables=["foo", "test"]) + output = pipeline({"foo": "123", "test": "456"}) + expected_output = { + "bar": "123 456foo", + "baz": "123 456foo 123foo", + "foo": "123", + "test": "456", + } + assert output == expected_output + + +def test_pipeline_usage_multiple_outputs() -> None: + """Test pipeline usage on multiple output chains.""" + chain_1 = FakeChain(input_variables=["foo"], output_variables=["bar", "test"]) + chain_2 = FakeChain(input_variables=["bar", "foo"], output_variables=["baz"]) + pipeline = Pipeline(chains=[chain_1, chain_2], input_variables=["foo"]) + output = pipeline({"foo": "123"}) + expected_output = { + "bar": "123foo", + "baz": "123foo 123foo", + "foo": "123", + "test": "123foo", + } + assert output == expected_output + + +def test_pipeline_missing_inputs() -> None: + """Test error is raised when input variables are missing.""" + chain_1 = FakeChain(input_variables=["foo"], output_variables=["bar"]) + chain_2 = FakeChain(input_variables=["bar", "test"], output_variables=["baz"]) + with pytest.raises(ValueError): + # Also needs "test" as an input + Pipeline(chains=[chain_1, chain_2], input_variables=["foo"]) + + +def test_pipeline_bad_outputs() -> None: + """Test error is raised when bad outputs are specified.""" + chain_1 = FakeChain(input_variables=["foo"], output_variables=["bar"]) + chain_2 = FakeChain(input_variables=["bar"], output_variables=["baz"]) + with pytest.raises(ValueError): + # "test" is not present as an output variable. + Pipeline( + chains=[chain_1, chain_2], + input_variables=["foo"], + output_variables=["test"], + ) + + +def test_pipeline_overlapping_inputs() -> None: + """Test error is raised when input variables are overlapping.""" + chain_1 = FakeChain(input_variables=["foo"], output_variables=["bar", "test"]) + chain_2 = FakeChain(input_variables=["bar"], output_variables=["baz"]) + with pytest.raises(ValueError): + # "test" is specified as an input, but also is an output of one step + Pipeline(chains=[chain_1, chain_2], input_variables=["foo", "test"]) diff --git a/tests/unit_tests/chains/test_simple_pipeline.py b/tests/unit_tests/chains/test_simple_pipeline.py new file mode 100644 index 00000000000..964aae6faa0 --- /dev/null +++ b/tests/unit_tests/chains/test_simple_pipeline.py @@ -0,0 +1,59 @@ +"""Test functionality around the simple pipeline chain.""" + +from typing import Dict, List + +import pytest +from pydantic import BaseModel + +from langchain.chains.base import Chain +from langchain.chains.simple_pipeline import SimplePipeline + + +class FakeChain(Chain, BaseModel): + """Fake chain for testing purposes.""" + + input_variables: List[str] + output_variables: List[str] + + @property + def input_keys(self) -> List[str]: + """Input keys this chain returns.""" + return self.input_variables + + @property + def output_keys(self) -> List[str]: + """Input keys this chain returns.""" + return self.output_variables + + def _call(self, inputs: Dict[str, str]) -> Dict[str, str]: + outputs = {} + for var in self.output_variables: + variables = [inputs[k] for k in self.input_variables] + outputs[var] = " ".join(variables) + "foo" + return outputs + + +def test_pipeline_functionality() -> None: + """Test simple pipeline functionality.""" + chain_1 = FakeChain(input_variables=["foo"], output_variables=["bar"]) + chain_2 = FakeChain(input_variables=["bar"], output_variables=["baz"]) + pipeline = SimplePipeline(chains=[chain_1, chain_2]) + output = pipeline({"input": "123"}) + expected_output = {"output": "123foofoo", "input": "123"} + assert output == expected_output + + +def test_multi_input_errors() -> None: + """Test pipeline errors if multiple input variables are expected.""" + chain_1 = FakeChain(input_variables=["foo"], output_variables=["bar"]) + chain_2 = FakeChain(input_variables=["bar", "foo"], output_variables=["baz"]) + with pytest.raises(ValueError): + SimplePipeline(chains=[chain_1, chain_2]) + + +def test_multi_output_errors() -> None: + """Test pipeline errors if multiple output variables are expected.""" + chain_1 = FakeChain(input_variables=["foo"], output_variables=["bar", "grok"]) + chain_2 = FakeChain(input_variables=["bar"], output_variables=["baz"]) + with pytest.raises(ValueError): + SimplePipeline(chains=[chain_1, chain_2])