mirror of
https://github.com/hwchase17/langchain.git
synced 2025-08-09 04:50:37 +00:00
add llm for loop
This commit is contained in:
parent
a57e74996f
commit
3ef44f41b7
290
docs/examples/chains/llm_for_loop.ipynb
Normal file
290
docs/examples/chains/llm_for_loop.ipynb
Normal file
@ -0,0 +1,290 @@
|
|||||||
|
{
|
||||||
|
"cells": [
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 20,
|
||||||
|
"id": "4c475754",
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"from langchain.prompts import PromptTemplate\n",
|
||||||
|
"from langchain.prompts.base import BaseOutputParser\n",
|
||||||
|
"from langchain import OpenAI, LLMChain\n",
|
||||||
|
"from langchain.chains.llm_for_loop import LLMForLoopChain"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 21,
|
||||||
|
"id": "efcdd239",
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"# First we make a chain that generates the list"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 13,
|
||||||
|
"id": "2b1884f5",
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"from typing import Optional\n",
|
||||||
|
"import re\n",
|
||||||
|
"class ListOutputParser(BaseOutputParser):\n",
|
||||||
|
" \n",
|
||||||
|
" def __init__(self, regex: Optional[str] = None):\n",
|
||||||
|
" self.regex=regex\n",
|
||||||
|
" \n",
|
||||||
|
" def parse(self, text: str) -> list:\n",
|
||||||
|
" splits = [t for t in text.split(\"\\n\") if t]\n",
|
||||||
|
" if self.regex is not None:\n",
|
||||||
|
" splits = [re.match(self.regex, s).group(1) for s in splits]\n",
|
||||||
|
" return splits"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 30,
|
||||||
|
"id": "b2b7f8fa",
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"template = \"\"\"You are a list maker. Your job is make lists given a certain user input.\n",
|
||||||
|
"\n",
|
||||||
|
"The format of your lists should be:\n",
|
||||||
|
"\n",
|
||||||
|
"```\n",
|
||||||
|
"List:\n",
|
||||||
|
"- Item 1\n",
|
||||||
|
"- Item 2\n",
|
||||||
|
"...\n",
|
||||||
|
"```\n",
|
||||||
|
"\n",
|
||||||
|
"Begin!:\n",
|
||||||
|
"\n",
|
||||||
|
"User input: {input}\n",
|
||||||
|
"List:\"\"\"\n",
|
||||||
|
"output_parser = ListOutputParser(regex=\"- (.*)\")\n",
|
||||||
|
"prompt = PromptTemplate(template=template, input_variables=[\"input\"], output_parser=output_parser)\n",
|
||||||
|
"\n",
|
||||||
|
"chain = LLMChain(llm=OpenAI(temperature=0), prompt=prompt)"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 44,
|
||||||
|
"id": "2f8ea6ba",
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [
|
||||||
|
{
|
||||||
|
"data": {
|
||||||
|
"text/plain": [
|
||||||
|
"['Tesla', 'Nissan', 'BMW', 'BYD', 'Volkswagen']"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
"execution_count": 44,
|
||||||
|
"metadata": {},
|
||||||
|
"output_type": "execute_result"
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"source": [
|
||||||
|
"chain.predict_and_parse(input=\"top 5 ev companies\")"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 32,
|
||||||
|
"id": "1fdfc7cb",
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"# Next we generate the chain that we run over each item"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 38,
|
||||||
|
"id": "0b8f115a",
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"template = \"\"\"For the following company, explain the origin of their name:\n",
|
||||||
|
"\n",
|
||||||
|
"Company: {company}\n",
|
||||||
|
"Explanation of their name:\"\"\"\n",
|
||||||
|
"prompt = PromptTemplate(template=template, input_variables=[\"company\"])\n",
|
||||||
|
"\n",
|
||||||
|
"explanation_chain = LLMChain(llm=OpenAI(), prompt=prompt, verbose=True)"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 39,
|
||||||
|
"id": "6d636881",
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [
|
||||||
|
{
|
||||||
|
"name": "stdout",
|
||||||
|
"output_type": "stream",
|
||||||
|
"text": [
|
||||||
|
"\n",
|
||||||
|
"\n",
|
||||||
|
"\u001b[1m> Entering new chain...\u001b[0m\n",
|
||||||
|
"Prompt after formatting:\n",
|
||||||
|
"\u001b[32;1m\u001b[1;3mFor the following company, explain the origin of their name:\n",
|
||||||
|
"\n",
|
||||||
|
"Company: Tesla\n",
|
||||||
|
"Explanation of their name:\u001b[0m\n",
|
||||||
|
"\n",
|
||||||
|
"\u001b[1m> Finished chain.\u001b[0m\n"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"data": {
|
||||||
|
"text/plain": [
|
||||||
|
"'\\n\\nTesla is a company that specializes in electric cars and renewable energy. The company is named after Nikola Tesla, a Serbian-American inventor and electrical engineer who was born in the 19th century.'"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
"execution_count": 39,
|
||||||
|
"metadata": {},
|
||||||
|
"output_type": "execute_result"
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"source": [
|
||||||
|
"explanation_chain.predict(company=\"Tesla\")"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 40,
|
||||||
|
"id": "3c236dd3",
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"# Now we combine them"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 41,
|
||||||
|
"id": "941b6389",
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"for_loop_chain = LLMForLoopChain(llm_chain=chain, apply_chain=explanation_chain)"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 45,
|
||||||
|
"id": "98c39dbc",
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [
|
||||||
|
{
|
||||||
|
"name": "stdout",
|
||||||
|
"output_type": "stream",
|
||||||
|
"text": [
|
||||||
|
"\n",
|
||||||
|
"\n",
|
||||||
|
"\u001b[1m> Entering new chain...\u001b[0m\n",
|
||||||
|
"Prompt after formatting:\n",
|
||||||
|
"\u001b[32;1m\u001b[1;3mFor the following company, explain the origin of their name:\n",
|
||||||
|
"\n",
|
||||||
|
"Company: Tesla\n",
|
||||||
|
"Explanation of their name:\u001b[0m\n",
|
||||||
|
"\n",
|
||||||
|
"\u001b[1m> Finished chain.\u001b[0m\n",
|
||||||
|
"\n",
|
||||||
|
"\n",
|
||||||
|
"\u001b[1m> Entering new chain...\u001b[0m\n",
|
||||||
|
"Prompt after formatting:\n",
|
||||||
|
"\u001b[32;1m\u001b[1;3mFor the following company, explain the origin of their name:\n",
|
||||||
|
"\n",
|
||||||
|
"Company: Nissan\n",
|
||||||
|
"Explanation of their name:\u001b[0m\n",
|
||||||
|
"\n",
|
||||||
|
"\u001b[1m> Finished chain.\u001b[0m\n",
|
||||||
|
"\n",
|
||||||
|
"\n",
|
||||||
|
"\u001b[1m> Entering new chain...\u001b[0m\n",
|
||||||
|
"Prompt after formatting:\n",
|
||||||
|
"\u001b[32;1m\u001b[1;3mFor the following company, explain the origin of their name:\n",
|
||||||
|
"\n",
|
||||||
|
"Company: BMW\n",
|
||||||
|
"Explanation of their name:\u001b[0m\n",
|
||||||
|
"\n",
|
||||||
|
"\u001b[1m> Finished chain.\u001b[0m\n",
|
||||||
|
"\n",
|
||||||
|
"\n",
|
||||||
|
"\u001b[1m> Entering new chain...\u001b[0m\n",
|
||||||
|
"Prompt after formatting:\n",
|
||||||
|
"\u001b[32;1m\u001b[1;3mFor the following company, explain the origin of their name:\n",
|
||||||
|
"\n",
|
||||||
|
"Company: BYD\n",
|
||||||
|
"Explanation of their name:\u001b[0m\n",
|
||||||
|
"\n",
|
||||||
|
"\u001b[1m> Finished chain.\u001b[0m\n",
|
||||||
|
"\n",
|
||||||
|
"\n",
|
||||||
|
"\u001b[1m> Entering new chain...\u001b[0m\n",
|
||||||
|
"Prompt after formatting:\n",
|
||||||
|
"\u001b[32;1m\u001b[1;3mFor the following company, explain the origin of their name:\n",
|
||||||
|
"\n",
|
||||||
|
"Company: Volkswagen\n",
|
||||||
|
"Explanation of their name:\u001b[0m\n",
|
||||||
|
"\n",
|
||||||
|
"\u001b[1m> Finished chain.\u001b[0m\n"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"data": {
|
||||||
|
"text/plain": [
|
||||||
|
"['\\n\\nTesla was named after the Serbian-American inventor Nikola Tesla, who was known for his work in electricity and magnetism.',\n",
|
||||||
|
" '\\n\\nNissan is a Japanese company, and their name comes from the Japanese word for \"sun.\"',\n",
|
||||||
|
" \"\\n\\nThe company's name is an abbreviation for Bayerische Motoren Werke, which is German for Bavarian Motor Works.\",\n",
|
||||||
|
" '\\n\\nThe company\\'s name is derived from the Chinese characters \"Baiyu Dong\", which literally mean \"to catch the rain in the east\". The name is a reference to the company\\'s origins in the city of Shenzhen, in southeastern China.',\n",
|
||||||
|
" '\\n\\nVolkswagen is a German car company. The word \"Volkswagen\" means \"people\\'s car\" in German.']"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
"execution_count": 45,
|
||||||
|
"metadata": {},
|
||||||
|
"output_type": "execute_result"
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"source": [
|
||||||
|
"for_loop_chain.run_list(input=\"top 5 ev companies\")"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": null,
|
||||||
|
"id": "5a2c1803",
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": []
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"metadata": {
|
||||||
|
"kernelspec": {
|
||||||
|
"display_name": "Python 3 (ipykernel)",
|
||||||
|
"language": "python",
|
||||||
|
"name": "python3"
|
||||||
|
},
|
||||||
|
"language_info": {
|
||||||
|
"codemirror_mode": {
|
||||||
|
"name": "ipython",
|
||||||
|
"version": 3
|
||||||
|
},
|
||||||
|
"file_extension": ".py",
|
||||||
|
"mimetype": "text/x-python",
|
||||||
|
"name": "python",
|
||||||
|
"nbconvert_exporter": "python",
|
||||||
|
"pygments_lexer": "ipython3",
|
||||||
|
"version": "3.7.6"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"nbformat": 4,
|
||||||
|
"nbformat_minor": 5
|
||||||
|
}
|
@ -1,5 +1,5 @@
|
|||||||
"""Chain that just formats a prompt and calls an LLM."""
|
"""Chain that just formats a prompt and calls an LLM."""
|
||||||
from typing import Any, Dict, List
|
from typing import Any, Dict, List, Union
|
||||||
|
|
||||||
from pydantic import BaseModel, Extra
|
from pydantic import BaseModel, Extra
|
||||||
|
|
||||||
@ -78,3 +78,11 @@ class LLMChain(Chain, BaseModel):
|
|||||||
completion = llm.predict(adjective="funny")
|
completion = llm.predict(adjective="funny")
|
||||||
"""
|
"""
|
||||||
return self(kwargs)[self.output_key]
|
return self(kwargs)[self.output_key]
|
||||||
|
|
||||||
|
def predict_and_parse(self, **kwargs: Any) -> Union[str, List[str], Dict[str, str]]:
|
||||||
|
"""Call predict and then parse the results."""
|
||||||
|
result = self.predict(**kwargs)
|
||||||
|
if self.prompt.output_parser is not None:
|
||||||
|
return self.prompt.output_parser.parse(result)
|
||||||
|
else:
|
||||||
|
return result
|
||||||
|
51
langchain/chains/llm_for_loop.py
Normal file
51
langchain/chains/llm_for_loop.py
Normal file
@ -0,0 +1,51 @@
|
|||||||
|
|
||||||
|
from typing import Dict, List, Any
|
||||||
|
|
||||||
|
from pydantic import BaseModel, Extra
|
||||||
|
|
||||||
|
from langchain.chains.base import Chain
|
||||||
|
from langchain.chains.llm import LLMChain
|
||||||
|
|
||||||
|
|
||||||
|
class LLMForLoopChain(Chain, BaseModel):
|
||||||
|
"""Chain that first uses an LLM to generate multiple items then loops over them."""
|
||||||
|
|
||||||
|
llm_chain: LLMChain
|
||||||
|
"""LLM chain to use to generate multiple items."""
|
||||||
|
apply_chain: Chain
|
||||||
|
"""Chain to apply to each item that is generated."""
|
||||||
|
output_key: str = "text" #: :meta private:
|
||||||
|
|
||||||
|
class Config:
|
||||||
|
"""Configuration for this pydantic object."""
|
||||||
|
|
||||||
|
extra = Extra.forbid
|
||||||
|
arbitrary_types_allowed = True
|
||||||
|
|
||||||
|
@property
|
||||||
|
def input_keys(self) -> List[str]:
|
||||||
|
"""Will be whatever keys the prompt expects.
|
||||||
|
|
||||||
|
:meta private:
|
||||||
|
"""
|
||||||
|
return self.llm_chain.prompt.input_variables
|
||||||
|
|
||||||
|
@property
|
||||||
|
def output_keys(self) -> List[str]:
|
||||||
|
"""Will always return text key.
|
||||||
|
|
||||||
|
:meta private:
|
||||||
|
"""
|
||||||
|
return [self.output_key]
|
||||||
|
|
||||||
|
def run_list(self, **kwargs: Any) -> List[str]:
|
||||||
|
"""Get list from LLM chain and then run chain on each item."""
|
||||||
|
output_items = self.llm_chain.predict_and_parse(**kwargs)
|
||||||
|
res = []
|
||||||
|
for item in output_items:
|
||||||
|
res.append(self.apply_chain.run(item))
|
||||||
|
return res
|
||||||
|
|
||||||
|
def _call(self, inputs: Dict[str, str]) -> Dict[str, str]:
|
||||||
|
res = self.run_list(**inputs)
|
||||||
|
return {self.output_key: "\n\n".join(res)}
|
@ -1,8 +1,8 @@
|
|||||||
"""BasePrompt schema definition."""
|
"""BasePrompt schema definition."""
|
||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
from typing import Any, Dict, List, Union
|
from typing import Any, Dict, List, Union, Optional
|
||||||
|
|
||||||
from pydantic import BaseModel, Field, root_validator
|
from pydantic import BaseModel, Field, root_validator, Extra
|
||||||
|
|
||||||
from langchain.formatting import formatter
|
from langchain.formatting import formatter
|
||||||
|
|
||||||
@ -29,7 +29,7 @@ def check_valid_template(
|
|||||||
raise ValueError("Invalid prompt schema.")
|
raise ValueError("Invalid prompt schema.")
|
||||||
|
|
||||||
|
|
||||||
class OutputParser(ABC):
|
class BaseOutputParser(ABC):
|
||||||
"""Class to parse the output of an LLM call."""
|
"""Class to parse the output of an LLM call."""
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
@ -37,22 +37,21 @@ class OutputParser(ABC):
|
|||||||
"""Parse the output of an LLM call."""
|
"""Parse the output of an LLM call."""
|
||||||
|
|
||||||
|
|
||||||
class DefaultParser(OutputParser):
|
|
||||||
"""Just return the text."""
|
|
||||||
|
|
||||||
def parse(self, text: str) -> Union[str, List[str], Dict[str, str]]:
|
|
||||||
"""Parse the output of an LLM call."""
|
|
||||||
return text
|
|
||||||
|
|
||||||
|
|
||||||
class BasePromptTemplate(BaseModel, ABC):
|
class BasePromptTemplate(BaseModel, ABC):
|
||||||
"""Base prompt should expose the format method, returning a prompt."""
|
"""Base prompt should expose the format method, returning a prompt."""
|
||||||
|
|
||||||
input_variables: List[str]
|
input_variables: List[str]
|
||||||
"""A list of the names of the variables the prompt template expects."""
|
"""A list of the names of the variables the prompt template expects."""
|
||||||
output_parser: OutputParser = Field(default_factory=DefaultParser)
|
output_parser: Optional[BaseOutputParser] = None
|
||||||
"""How to parse the output of calling an LLM on this formatted prompt."""
|
"""How to parse the output of calling an LLM on this formatted prompt."""
|
||||||
|
|
||||||
|
class Config:
|
||||||
|
"""Configuration for this pydantic object."""
|
||||||
|
|
||||||
|
extra = Extra.forbid
|
||||||
|
arbitrary_types_allowed = True
|
||||||
|
|
||||||
@root_validator()
|
@root_validator()
|
||||||
def validate_variable_names(cls, values: Dict) -> Dict:
|
def validate_variable_names(cls, values: Dict) -> Dict:
|
||||||
"""Validate variable names do not restricted names."""
|
"""Validate variable names do not restricted names."""
|
||||||
|
Loading…
Reference in New Issue
Block a user