mirror of
https://github.com/hwchase17/langchain.git
synced 2025-08-07 12:06:43 +00:00
sequential chain from prompts
This commit is contained in:
parent
fc66a32c6f
commit
6dcbb74582
@ -1,10 +1,11 @@
|
||||
"""Chain pipeline where the outputs of one step feed directly into next."""
|
||||
|
||||
from typing import Dict, List
|
||||
from typing import Dict, List, Tuple
|
||||
|
||||
from pydantic import BaseModel, Extra, root_validator
|
||||
|
||||
from langchain.chains.base import Chain
|
||||
from langchain.chains.llm import LLMChain
|
||||
from langchain.input import get_color_mapping, print_text
|
||||
|
||||
|
||||
@ -135,3 +136,18 @@ class SimpleSequentialChain(Chain, BaseModel):
|
||||
if self.verbose:
|
||||
print_text(_input, color=color_mapping[str(i)], end="\n")
|
||||
return {self.output_key: _input}
|
||||
|
||||
|
||||
def construct_sequential_llm_chain(
|
||||
llm_chain: LLMChain, add_ons: List[Tuple[str, List[str], str]]
|
||||
) -> SequentialChain:
|
||||
base_prompt = llm_chain.prompt
|
||||
chains = [llm_chain]
|
||||
for template, input_vars, output_key in add_ons:
|
||||
new_prompt = base_prompt.extend_prompt(template, input_vars)
|
||||
new_llm_chain = LLMChain(
|
||||
llm=llm_chain.llm, prompt=new_prompt, output_key=output_key
|
||||
)
|
||||
chains.append(new_llm_chain)
|
||||
|
||||
return SequentialChain(chains=chains, input_variables=llm_chain.input_keys)
|
||||
|
@ -1,4 +1,6 @@
|
||||
"""BasePrompt schema definition."""
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
from abc import ABC, abstractmethod
|
||||
from pathlib import Path
|
||||
@ -62,6 +64,12 @@ class BasePromptTemplate(BaseModel, ABC):
|
||||
extra = Extra.forbid
|
||||
arbitrary_types_allowed = True
|
||||
|
||||
@abstractmethod
|
||||
def extend_prompt(
|
||||
self, template: str, input_variables: List[str]
|
||||
) -> BasePromptTemplate:
|
||||
"""Extend the prompt with another template/input variables."""
|
||||
|
||||
@root_validator()
|
||||
def validate_variable_names(cls, values: Dict) -> Dict:
|
||||
"""Validate variable names do not restricted names."""
|
||||
|
@ -1,4 +1,6 @@
|
||||
"""Prompt template that contains few shot examples."""
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from pydantic import BaseModel, Extra, root_validator
|
||||
@ -41,6 +43,20 @@ class FewShotPromptTemplate(BasePromptTemplate, BaseModel):
|
||||
template_format: str = "f-string"
|
||||
"""The format of the prompt template. Options are: 'f-string'."""
|
||||
|
||||
def extend_prompt(
|
||||
self, template: str, input_variables: List[str]
|
||||
) -> FewShotPromptTemplate:
|
||||
"""Append to template and input variables."""
|
||||
copied_prompt = self.copy(deep=True)
|
||||
copied_prompt.suffix += template
|
||||
copied_prompt.input_variables += input_variables
|
||||
check_valid_template(
|
||||
copied_prompt.prefix + copied_prompt.suffix,
|
||||
copied_prompt.template_format,
|
||||
copied_prompt.input_variables,
|
||||
)
|
||||
return copied_prompt
|
||||
|
||||
@root_validator(pre=True)
|
||||
def check_examples_and_selector(cls, values: Dict) -> Dict:
|
||||
"""Check that one and only one of examples/example_selector are provided."""
|
||||
|
@ -36,6 +36,20 @@ class PromptTemplate(BasePromptTemplate, BaseModel):
|
||||
|
||||
extra = Extra.forbid
|
||||
|
||||
def extend_prompt(
|
||||
self, template: str, input_variables: List[str]
|
||||
) -> PromptTemplate:
|
||||
"""Append to template and input variables."""
|
||||
copied_prompt = self.copy(deep=True)
|
||||
copied_prompt.template += template
|
||||
copied_prompt.input_variables += input_variables
|
||||
check_valid_template(
|
||||
copied_prompt.template,
|
||||
copied_prompt.template_format,
|
||||
copied_prompt.input_variables,
|
||||
)
|
||||
return copied_prompt
|
||||
|
||||
def format(self, **kwargs: Any) -> str:
|
||||
"""Format the prompt with the inputs.
|
||||
|
||||
|
@ -5,7 +5,14 @@ import pytest
|
||||
from pydantic import BaseModel
|
||||
|
||||
from langchain.chains.base import Chain
|
||||
from langchain.chains.sequential import SequentialChain, SimpleSequentialChain
|
||||
from langchain.chains.llm import LLMChain
|
||||
from langchain.chains.sequential import (
|
||||
SequentialChain,
|
||||
SimpleSequentialChain,
|
||||
construct_sequential_llm_chain,
|
||||
)
|
||||
from langchain.prompts.prompt import PromptTemplate
|
||||
from tests.unit_tests.llms.fake_llm import FakeLLM
|
||||
|
||||
|
||||
class FakeChain(Chain, BaseModel):
|
||||
@ -138,3 +145,21 @@ def test_multi_output_errors() -> None:
|
||||
chain_2 = FakeChain(input_variables=["bar"], output_variables=["baz"])
|
||||
with pytest.raises(ValueError):
|
||||
SimpleSequentialChain(chains=[chain_1, chain_2])
|
||||
|
||||
|
||||
def test_construct_sequential_llm_chain() -> None:
|
||||
"""Test constructing simple sequential chain."""
|
||||
prompt = PromptTemplate(template="what is {foo}?", input_variables=["foo"])
|
||||
llm_chain = LLMChain(llm=FakeLLM(), prompt=prompt, output_key="bar")
|
||||
add_ons = [("{bar} and what does it do?", ["bar"], "baz")]
|
||||
chain = construct_sequential_llm_chain(llm_chain, add_ons)
|
||||
|
||||
expected_new_prompt = PromptTemplate(
|
||||
template="what is {foo}?{bar} and what does it do?",
|
||||
input_variables=["foo", "bar"],
|
||||
)
|
||||
expected_new_chain = LLMChain(
|
||||
llm=FakeLLM(), prompt=expected_new_prompt, output_key="baz"
|
||||
)
|
||||
expected_chains = [llm_chain, expected_new_chain]
|
||||
assert chain.chains == expected_chains
|
||||
|
Loading…
Reference in New Issue
Block a user