mirror of https://github.com/hwchase17/langchain.git
synced 2025-08-08 04:25:46 +00:00

sequential chain from prompts

parent fc66a32c6f
commit 6dcbb74582
langchain/chains/sequential.py
@@ -1,10 +1,11 @@
 """Chain pipeline where the outputs of one step feed directly into next."""
-from typing import Dict, List
+from typing import Dict, List, Tuple

 from pydantic import BaseModel, Extra, root_validator

 from langchain.chains.base import Chain
+from langchain.chains.llm import LLMChain
 from langchain.input import get_color_mapping, print_text

@@ -135,3 +136,18 @@ class SimpleSequentialChain(Chain, BaseModel):
         if self.verbose:
             print_text(_input, color=color_mapping[str(i)], end="\n")
         return {self.output_key: _input}
+
+
+def construct_sequential_llm_chain(
+    llm_chain: LLMChain, add_ons: List[Tuple[str, List[str], str]]
+) -> SequentialChain:
+    base_prompt = llm_chain.prompt
+    chains = [llm_chain]
+    for template, input_vars, output_key in add_ons:
+        new_prompt = base_prompt.extend_prompt(template, input_vars)
+        new_llm_chain = LLMChain(
+            llm=llm_chain.llm, prompt=new_prompt, output_key=output_key
+        )
+        chains.append(new_llm_chain)
+
+    return SequentialChain(chains=chains, input_variables=llm_chain.input_keys)
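A hedged usage sketch of the new construct_sequential_llm_chain helper, mirroring the unit test added at the bottom of this commit; FakeLLM is the repo's test stub, and any real LLM instance would work in its place:

from langchain.chains.llm import LLMChain
from langchain.chains.sequential import construct_sequential_llm_chain
from langchain.prompts.prompt import PromptTemplate
from tests.unit_tests.llms.fake_llm import FakeLLM

# Base chain: asks the first question and stores its answer under "bar".
prompt = PromptTemplate(template="what is {foo}?", input_variables=["foo"])
llm_chain = LLMChain(llm=FakeLLM(), prompt=prompt, output_key="bar")

# Each add-on is (template_to_append, new_input_variables, output_key);
# each template is appended to the *base* prompt, not to the previous add-on.
add_ons = [("{bar} and what does it do?", ["bar"], "baz")]
chain = construct_sequential_llm_chain(llm_chain, add_ons)
# chain now runs two LLMChains in sequence, feeding "bar" into the second.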
langchain/prompts/base.py
@@ -1,4 +1,6 @@
 """BasePrompt schema definition."""
+from __future__ import annotations
+
 import json
 from abc import ABC, abstractmethod
 from pathlib import Path
@@ -62,6 +64,12 @@ class BasePromptTemplate(BaseModel, ABC):
         extra = Extra.forbid
         arbitrary_types_allowed = True

+    @abstractmethod
+    def extend_prompt(
+        self, template: str, input_variables: List[str]
+    ) -> BasePromptTemplate:
+        """Extend the prompt with another template/input variables."""
+
     @root_validator()
     def validate_variable_names(cls, values: Dict) -> Dict:
         """Validate variable names do not restricted names."""
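The "from __future__ import annotations" lines added in these prompt files are what let each extend_prompt annotate its return type with the class still being defined; a minimal standalone illustration (the class name here is invented):

from __future__ import annotations


class PromptTemplateSketch:
    # Without the __future__ import, evaluating this return annotation at
    # class-definition time would raise NameError, because the name
    # PromptTemplateSketch is not bound until the class body finishes.
    def extend_prompt(self, template: str) -> PromptTemplateSketch:
        ...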
langchain/prompts/few_shot.py
@@ -1,4 +1,6 @@
 """Prompt template that contains few shot examples."""
+from __future__ import annotations
+
 from typing import Any, Dict, List, Optional

 from pydantic import BaseModel, Extra, root_validator
@@ -41,6 +43,20 @@ class FewShotPromptTemplate(BasePromptTemplate, BaseModel):
     template_format: str = "f-string"
     """The format of the prompt template. Options are: 'f-string'."""

+    def extend_prompt(
+        self, template: str, input_variables: List[str]
+    ) -> FewShotPromptTemplate:
+        """Append to template and input variables."""
+        copied_prompt = self.copy(deep=True)
+        copied_prompt.suffix += template
+        copied_prompt.input_variables += input_variables
+        check_valid_template(
+            copied_prompt.prefix + copied_prompt.suffix,
+            copied_prompt.template_format,
+            copied_prompt.input_variables,
+        )
+        return copied_prompt
+
     @root_validator(pre=True)
     def check_examples_and_selector(cls, values: Dict) -> Dict:
         """Check that one and only one of examples/example_selector are provided."""
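A hedged sketch of the few-shot variant, with invented field values; extend_prompt appends to the suffix, validates the concatenated prefix and suffix, and returns a deep copy:

from langchain.prompts.few_shot import FewShotPromptTemplate
from langchain.prompts.prompt import PromptTemplate

example_prompt = PromptTemplate(
    template="Q: {question}\nA: {answer}", input_variables=["question", "answer"]
)
few_shot = FewShotPromptTemplate(
    examples=[{"question": "2+2?", "answer": "4"}],
    example_prompt=example_prompt,
    prefix="Answer the questions.",
    suffix="Q: {input}\nA:",
    input_variables=["input"],
)
# Returns a new template whose suffix ends with the appended text; the
# original few_shot object is left unchanged.
extended = few_shot.extend_prompt(" Then explain {topic}.", ["topic"])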
langchain/prompts/prompt.py
@@ -36,6 +36,20 @@ class PromptTemplate(BasePromptTemplate, BaseModel):

         extra = Extra.forbid

+    def extend_prompt(
+        self, template: str, input_variables: List[str]
+    ) -> PromptTemplate:
+        """Append to template and input variables."""
+        copied_prompt = self.copy(deep=True)
+        copied_prompt.template += template
+        copied_prompt.input_variables += input_variables
+        check_valid_template(
+            copied_prompt.template,
+            copied_prompt.template_format,
+            copied_prompt.input_variables,
+        )
+        return copied_prompt
+
     def format(self, **kwargs: Any) -> str:
         """Format the prompt with the inputs.

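And the plain PromptTemplate version, using the same values as the new unit test below; the deep copy keeps the original intact:

from langchain.prompts.prompt import PromptTemplate

prompt = PromptTemplate(template="what is {foo}?", input_variables=["foo"])
extended = prompt.extend_prompt("{bar} and what does it do?", ["bar"])

assert extended.template == "what is {foo}?{bar} and what does it do?"
assert extended.input_variables == ["foo", "bar"]
assert prompt.template == "what is {foo}?"  # original untouched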
tests/unit_tests/chains/test_sequential.py
@@ -5,7 +5,14 @@ import pytest
 from pydantic import BaseModel

 from langchain.chains.base import Chain
-from langchain.chains.sequential import SequentialChain, SimpleSequentialChain
+from langchain.chains.llm import LLMChain
+from langchain.chains.sequential import (
+    SequentialChain,
+    SimpleSequentialChain,
+    construct_sequential_llm_chain,
+)
+from langchain.prompts.prompt import PromptTemplate
+from tests.unit_tests.llms.fake_llm import FakeLLM


 class FakeChain(Chain, BaseModel):
@@ -138,3 +145,21 @@ def test_multi_output_errors() -> None:
     chain_2 = FakeChain(input_variables=["bar"], output_variables=["baz"])
     with pytest.raises(ValueError):
         SimpleSequentialChain(chains=[chain_1, chain_2])
+
+
+def test_construct_sequential_llm_chain() -> None:
+    """Test constructing simple sequential chain."""
+    prompt = PromptTemplate(template="what is {foo}?", input_variables=["foo"])
+    llm_chain = LLMChain(llm=FakeLLM(), prompt=prompt, output_key="bar")
+    add_ons = [("{bar} and what does it do?", ["bar"], "baz")]
+    chain = construct_sequential_llm_chain(llm_chain, add_ons)
+
+    expected_new_prompt = PromptTemplate(
+        template="what is {foo}?{bar} and what does it do?",
+        input_variables=["foo", "bar"],
+    )
+    expected_new_chain = LLMChain(
+        llm=FakeLLM(), prompt=expected_new_prompt, output_key="baz"
+    )
+    expected_chains = [llm_chain, expected_new_chain]
+    assert chain.chains == expected_chains
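With this commit checked out locally, the new test can be run with `pytest tests/unit_tests/chains/test_sequential.py -k construct_sequential_llm_chain`.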