Compare commits

...

4 Commits

Author SHA1 Message Date
Harrison Chase
620484f3ea cr 2022-11-19 09:36:22 -08:00
Harrison Chase
3fcc803880 Merge branch 'master' into harrison/chain_pipeline 2022-11-19 09:34:05 -08:00
Harrison Chase
9ce01f4281 cr 2022-11-16 08:28:36 -08:00
Harrison Chase
9ea66bd1f9 wip chain pipelines 2022-11-14 08:40:57 -08:00
6 changed files with 240 additions and 0 deletions

View File

View File

@@ -0,0 +1,74 @@
"""Chain that generates a list and then maps each output to another chain."""
from typing import Dict, List
from pydantic import BaseModel, Extra, root_validator
from langchain.chains.base import Chain
from langchain.chains.llm import LLMChain
class MapChain(Chain, BaseModel):
"""Chain that generates a list and then maps each output to another chain."""
llm_chain: LLMChain
map_chain: Chain
n: int
output_key_prefix: str = "output" #: :meta private:
class Config:
"""Configuration for this pydantic object."""
extra = Extra.forbid
arbitrary_types_allowed = True
@property
def input_keys(self) -> List[str]:
"""Will be whatever keys the prompt expects.
:meta private:
"""
vars = self.llm_chain.prompt.input_variables
return [v for v in vars if v != "n"]
@property
def output_keys(self) -> List[str]:
"""Return output key.
:meta private:
"""
return [f"{self.output_key_prefix}_{i}" for i in range(self.n)]
@root_validator()
def validate_llm_chain(cls, values: Dict) -> Dict:
"""Check that llm chain takes as input `n`."""
input_vars = values["llm_chain"].prompt.input_variables
if "n" not in input_vars:
raise ValueError(
"For MapChains, `n` should be one of the input variables to "
f"llm_chain, only got {input_vars}"
)
return values
@root_validator()
def validate_map_chain(cls, values: Dict) -> Dict:
"""Check that map chain takes a single input."""
map_chain_inputs = values["map_chain"].input_keys
if len(map_chain_inputs) != 1:
raise ValueError(
"For MapChains, the map_chain should take a single input,"
f" got {map_chain_inputs}."
)
return values
def _call(self, inputs: Dict[str, str]) -> Dict[str, str]:
_inputs = {key: inputs[key] for key in self.input_keys}
_inputs["n"] = self.n
output = self.llm_chain.predict(**_inputs)
new_inputs = output.split("\n")
if len(new_inputs) != self.n:
raise ValueError(
f"Got {len(new_inputs)} items, but expected to get {self.n}"
)
outputs = {self.map_chain.run(text) for text in new_inputs}
return outputs

View File

View File

@@ -0,0 +1,71 @@
"""Chain pipeline where the outputs of one step feed directly into next."""
from typing import Dict, List
from pydantic import BaseModel, Extra, root_validator
from langchain.chains.base import Chain
class Pipeline(Chain, BaseModel):
"""Chain pipeline where the outputs of one step feed directly into next."""
chains: List[Chain]
input_variables: List[str]
output_variables: List[str] #: :meta private:
class Config:
"""Configuration for this pydantic object."""
extra = Extra.forbid
arbitrary_types_allowed = True
@property
def input_keys(self) -> List[str]:
"""Expect input key.
:meta private:
"""
return self.input_variables
@property
def output_keys(self) -> List[str]:
"""Return output key.
:meta private:
"""
return self.output_variables
@root_validator(pre=True)
def validate_chains(cls, values: Dict) -> Dict:
"""Validate that the correct inputs exist for all chains."""
chains = values["chains"]
input_variables = values["input_variables"]
known_variables = set(input_variables)
for chain in chains:
missing_vars = set(chain.input_keys).difference(known_variables)
if missing_vars:
raise ValueError(f"Missing required input keys: {missing_vars}")
overlapping_keys = known_variables.intersection(chain.output_keys)
if overlapping_keys:
raise ValueError(
f"Chain returned keys that already exist: {overlapping_keys}"
)
known_variables |= set(chain.output_keys)
if "output_variables" not in values:
values["output_variables"] = known_variables.difference(input_variables)
else:
missing_vars = known_variables.difference(values["output_variables"])
if missing_vars:
raise ValueError(
f"Expected output variables that were not found: {missing_vars}."
)
return values
def _call(self, inputs: Dict[str, str]) -> Dict[str, str]:
known_values = inputs.copy()
for chain in self.chains:
outputs = chain(known_values)
known_values.update(outputs)
return {k: known_values[k] for k in self.output_variables}

View File

@@ -0,0 +1,61 @@
"""Simple chain pipeline where the outputs of one step feed directly into next."""
from typing import Dict, List
from pydantic import BaseModel, Extra, root_validator
from langchain.chains.base import Chain
class SimplePipeline(Chain, BaseModel):
"""Simple chain pipeline where the outputs of one step feed directly into next."""
chains: List[Chain]
input_key: str = "input" #: :meta private:
output_key: str = "output" #: :meta private:
class Config:
"""Configuration for this pydantic object."""
extra = Extra.forbid
arbitrary_types_allowed = True
@property
def input_keys(self) -> List[str]:
"""Expect input key.
:meta private:
"""
return [self.input_key]
@property
def output_keys(self) -> List[str]:
"""Return output key.
:meta private:
"""
return [self.output_key]
@root_validator()
def validate_chains(cls, values: Dict) -> Dict:
"""Validate that chains are all single input/output."""
for chain in values["chains"]:
if len(chain.input_keys) != 1:
raise ValueError(
"Chains used in SimplePipeline should all have one input, got "
f"{chain} with {len(chain.input_keys)} inputs."
)
if len(chain.output_keys) != 1:
raise ValueError(
"Chains used in SimplePipeline should all have one output, got "
f"{chain} with {len(chain.output_keys)} outputs."
)
return values
def _call(self, inputs: Dict[str, str]) -> Dict[str, str]:
_input = inputs[self.input_key]
for chain in self.chains:
_input = chain.run(_input)
# Clean the input
_input = _input.strip()
return {self.output_key: _input}

View File

@@ -0,0 +1,34 @@
from typing import Dict, List
from pydantic import BaseModel
from langchain.chains.base import Chain
from langchain.chains.pipeline import Pipeline
class FakeChain(Chain, BaseModel):
input_variables: List[str]
output_variables: List[str]
@property
def input_keys(self) -> List[str]:
return self.input_variables
@property
def output_keys(self) -> List[str]:
return self.output_variables
def _call(self, inputs: Dict[str, str]) -> Dict[str, str]:
outputs = {}
for var in self.output_variables:
outputs[var] = " ".join(self.input_variables) + "foo"
return outputs
def test_pipeline_usage() -> None:
chain_1 = FakeChain(input_variables=["foo"], output_variables=["bar"])
chain_2 = FakeChain(input_variables=["bar"], output_variables=["baz"])
pipeline = Pipeline(chains=[chain_1, chain_2], input_variables=["foo"])
output = pipeline({"foo": "123"})
breakpoint()