mirror of
https://github.com/hwchase17/langchain.git
synced 2025-08-27 13:31:53 +00:00
wip chain pipelines
This commit is contained in:
parent
ced29b816b
commit
9ea66bd1f9
0
langchain/chains/map/__init__.py
Normal file
0
langchain/chains/map/__init__.py
Normal file
73
langchain/chains/map/base.py
Normal file
73
langchain/chains/map/base.py
Normal file
@ -0,0 +1,73 @@
|
|||||||
|
"""Chain that generates a list and then maps each output to another chain."""
|
||||||
|
|
||||||
|
from langchain.chains.base import Chain
|
||||||
|
from langchain.chains.llm import LLMChain
|
||||||
|
from pydantic import BaseModel, Extra, root_validator
|
||||||
|
from typing import List, Dict
|
||||||
|
|
||||||
|
|
||||||
|
class MapChain(Chain, BaseModel):
|
||||||
|
"""Chain that generates a list and then maps each output to another chain."""
|
||||||
|
|
||||||
|
llm_chain: LLMChain
|
||||||
|
map_chain: Chain
|
||||||
|
n: int
|
||||||
|
output_key_prefix: str = "output" #: :meta private:
|
||||||
|
|
||||||
|
class Config:
|
||||||
|
"""Configuration for this pydantic object."""
|
||||||
|
|
||||||
|
extra = Extra.forbid
|
||||||
|
arbitrary_types_allowed = True
|
||||||
|
|
||||||
|
@property
|
||||||
|
def input_keys(self) -> List[str]:
|
||||||
|
"""Will be whatever keys the prompt expects.
|
||||||
|
|
||||||
|
:meta private:
|
||||||
|
"""
|
||||||
|
vars = self.llm_chain.prompt.input_variables
|
||||||
|
return [v for v in vars if v != "n"]
|
||||||
|
|
||||||
|
@property
|
||||||
|
def output_keys(self) -> List[str]:
|
||||||
|
"""Return output key.
|
||||||
|
|
||||||
|
:meta private:
|
||||||
|
"""
|
||||||
|
return [f"{self.output_key_prefix}_{i}" for i in range(self.n)]
|
||||||
|
|
||||||
|
@root_validator()
|
||||||
|
def validate_llm_chain(cls, values: Dict) -> Dict:
|
||||||
|
"""Check that llm chain takes as input `n`."""
|
||||||
|
input_vars = values["llm_chain"].prompt.input_variables
|
||||||
|
if "n" not in input_vars:
|
||||||
|
raise ValueError(
|
||||||
|
"For MapChains, `n` should be one of the input variables to "
|
||||||
|
f"llm_chain, only got {input_vars}"
|
||||||
|
)
|
||||||
|
return values
|
||||||
|
|
||||||
|
@root_validator()
|
||||||
|
def validate_map_chain(cls, values: Dict) -> Dict:
|
||||||
|
"""Check that map chain takes a single input."""
|
||||||
|
map_chain_inputs = values["map_chain"].input_keys
|
||||||
|
if len(map_chain_inputs) != 1:
|
||||||
|
raise ValueError(
|
||||||
|
"For MapChains, the map_chain should take a single input,"
|
||||||
|
f" got {map_chain_inputs}."
|
||||||
|
)
|
||||||
|
return values
|
||||||
|
|
||||||
|
def _call(self, inputs: Dict[str, str]) -> Dict[str, str]:
|
||||||
|
_inputs = {key: inputs[key] for key in self.input_keys}
|
||||||
|
_inputs["n"] = self.n
|
||||||
|
output = self.llm_chain.predict(**_inputs)
|
||||||
|
new_inputs = output.split("\n")
|
||||||
|
if len(new_inputs) != self.n:
|
||||||
|
raise ValueError(
|
||||||
|
f"Got {len(new_inputs)} items, but expected to get {self.n}"
|
||||||
|
)
|
||||||
|
outputs = {self.map_chain.run(text) for text in new_inputs}
|
||||||
|
return outputs
|
||||||
|
|
0
langchain/chains/map/prompt.py
Normal file
0
langchain/chains/map/prompt.py
Normal file
61
langchain/chains/simple_pipeline.py
Normal file
61
langchain/chains/simple_pipeline.py
Normal file
@ -0,0 +1,61 @@
|
|||||||
|
"""Simple chain pipeline where the outputs of one step feed directly into next."""
|
||||||
|
|
||||||
|
from langchain.chains.base import Chain
|
||||||
|
from pydantic import BaseModel, Extra, root_validator
|
||||||
|
from typing import List, Dict
|
||||||
|
|
||||||
|
|
||||||
|
class SimplePipeline(Chain, BaseModel):
|
||||||
|
"""Simple chain pipeline where the outputs of one step feed directly into next."""
|
||||||
|
|
||||||
|
chains: List[Chain]
|
||||||
|
input_key: str = "input" #: :meta private:
|
||||||
|
output_key: str = "output" #: :meta private:
|
||||||
|
|
||||||
|
class Config:
|
||||||
|
"""Configuration for this pydantic object."""
|
||||||
|
|
||||||
|
extra = Extra.forbid
|
||||||
|
arbitrary_types_allowed = True
|
||||||
|
|
||||||
|
@property
|
||||||
|
def input_keys(self) -> List[str]:
|
||||||
|
"""Expect input key.
|
||||||
|
|
||||||
|
:meta private:
|
||||||
|
"""
|
||||||
|
return [self.input_key]
|
||||||
|
|
||||||
|
@property
|
||||||
|
def output_keys(self) -> List[str]:
|
||||||
|
"""Return output key.
|
||||||
|
|
||||||
|
:meta private:
|
||||||
|
"""
|
||||||
|
return [self.output_key]
|
||||||
|
|
||||||
|
@root_validator()
|
||||||
|
def validate_chains(cls, values: Dict) -> Dict:
|
||||||
|
"""Validate that chains are all single input/output."""
|
||||||
|
for chain in values["chains"]:
|
||||||
|
if len(chain.input_keys) != 1:
|
||||||
|
raise ValueError(
|
||||||
|
"Chains used in SimplePipeline should all have one input, got "
|
||||||
|
f"{chain} with {len(chain.input_keys)} inputs."
|
||||||
|
)
|
||||||
|
if len(chain.output_keys) != 1:
|
||||||
|
raise ValueError(
|
||||||
|
"Chains used in SimplePipeline should all have one output, got "
|
||||||
|
f"{chain} with {len(chain.output_keys)} outputs."
|
||||||
|
)
|
||||||
|
return values
|
||||||
|
|
||||||
|
def _call(self, inputs: Dict[str, str]) -> Dict[str, str]:
|
||||||
|
_input = inputs[self.input_key]
|
||||||
|
for chain in self.chains:
|
||||||
|
_input = chain.run(_input)
|
||||||
|
# Clean the input
|
||||||
|
_input = _input.strip(' \t\n\r')
|
||||||
|
return {self.output_key: _input}
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue
Block a user