From 9ea66bd1f93100d5bf818ca1c3ac130c240895de Mon Sep 17 00:00:00 2001 From: Harrison Chase Date: Mon, 14 Nov 2022 08:40:57 -0800 Subject: [PATCH] wip chain pipelines --- langchain/chains/map/__init__.py | 0 langchain/chains/map/base.py | 73 +++++++++++++++++++++++++++++ langchain/chains/map/prompt.py | 0 langchain/chains/simple_pipeline.py | 61 ++++++++++++++++++++++++ 4 files changed, 134 insertions(+) create mode 100644 langchain/chains/map/__init__.py create mode 100644 langchain/chains/map/base.py create mode 100644 langchain/chains/map/prompt.py create mode 100644 langchain/chains/simple_pipeline.py diff --git a/langchain/chains/map/__init__.py b/langchain/chains/map/__init__.py new file mode 100644 index 00000000000..e69de29bb2d diff --git a/langchain/chains/map/base.py b/langchain/chains/map/base.py new file mode 100644 index 00000000000..d2c3beda0c5 --- /dev/null +++ b/langchain/chains/map/base.py @@ -0,0 +1,73 @@ +"""Chain that generates a list and then maps each output to another chain.""" + +from langchain.chains.base import Chain +from langchain.chains.llm import LLMChain +from pydantic import BaseModel, Extra, root_validator +from typing import List, Dict + + +class MapChain(Chain, BaseModel): + """Chain that generates a list and then maps each output to another chain.""" + + llm_chain: LLMChain + map_chain: Chain + n: int + output_key_prefix: str = "output" #: :meta private: + + class Config: + """Configuration for this pydantic object.""" + + extra = Extra.forbid + arbitrary_types_allowed = True + + @property + def input_keys(self) -> List[str]: + """Will be whatever keys the prompt expects. + + :meta private: + """ + vars = self.llm_chain.prompt.input_variables + return [v for v in vars if v != "n"] + + @property + def output_keys(self) -> List[str]: + """Return output key. + + :meta private: + """ + return [f"{self.output_key_prefix}_{i}" for i in range(self.n)] + + @root_validator() + def validate_llm_chain(cls, values: Dict) -> Dict: + """Check that llm chain takes as input `n`.""" + input_vars = values["llm_chain"].prompt.input_variables + if "n" not in input_vars: + raise ValueError( + "For MapChains, `n` should be one of the input variables to " + f"llm_chain, only got {input_vars}" + ) + return values + + @root_validator() + def validate_map_chain(cls, values: Dict) -> Dict: + """Check that map chain takes a single input.""" + map_chain_inputs = values["map_chain"].input_keys + if len(map_chain_inputs) != 1: + raise ValueError( + "For MapChains, the map_chain should take a single input," + f" got {map_chain_inputs}." + ) + return values + + def _call(self, inputs: Dict[str, str]) -> Dict[str, str]: + _inputs = {key: inputs[key] for key in self.input_keys} + _inputs["n"] = self.n + output = self.llm_chain.predict(**_inputs) + new_inputs = output.split("\n") + if len(new_inputs) != self.n: + raise ValueError( + f"Got {len(new_inputs)} items, but expected to get {self.n}" + ) + outputs = {self.map_chain.run(text) for text in new_inputs} + return outputs + diff --git a/langchain/chains/map/prompt.py b/langchain/chains/map/prompt.py new file mode 100644 index 00000000000..e69de29bb2d diff --git a/langchain/chains/simple_pipeline.py b/langchain/chains/simple_pipeline.py new file mode 100644 index 00000000000..a36b19336b1 --- /dev/null +++ b/langchain/chains/simple_pipeline.py @@ -0,0 +1,61 @@ +"""Simple chain pipeline where the outputs of one step feed directly into next.""" + +from langchain.chains.base import Chain +from pydantic import BaseModel, Extra, root_validator +from typing import List, Dict + + +class SimplePipeline(Chain, BaseModel): + """Simple chain pipeline where the outputs of one step feed directly into next.""" + + chains: List[Chain] + input_key: str = "input" #: :meta private: + output_key: str = "output" #: :meta private: + + class Config: + """Configuration for this pydantic object.""" + + extra = Extra.forbid + arbitrary_types_allowed = True + + @property + def input_keys(self) -> List[str]: + """Expect input key. + + :meta private: + """ + return [self.input_key] + + @property + def output_keys(self) -> List[str]: + """Return output key. + + :meta private: + """ + return [self.output_key] + + @root_validator() + def validate_chains(cls, values: Dict) -> Dict: + """Validate that chains are all single input/output.""" + for chain in values["chains"]: + if len(chain.input_keys) != 1: + raise ValueError( + "Chains used in SimplePipeline should all have one input, got " + f"{chain} with {len(chain.input_keys)} inputs." + ) + if len(chain.output_keys) != 1: + raise ValueError( + "Chains used in SimplePipeline should all have one output, got " + f"{chain} with {len(chain.output_keys)} outputs." + ) + return values + + def _call(self, inputs: Dict[str, str]) -> Dict[str, str]: + _input = inputs[self.input_key] + for chain in self.chains: + _input = chain.run(_input) + # Clean the input + _input = _input.strip(' \t\n\r') + return {self.output_key: _input} + +