mirror of
https://github.com/hwchase17/langchain.git
synced 2025-08-06 19:48:26 +00:00
add llm for loop
This commit is contained in:
parent
a57e74996f
commit
3ef44f41b7
290
docs/examples/chains/llm_for_loop.ipynb
Normal file
290
docs/examples/chains/llm_for_loop.ipynb
Normal file
@ -0,0 +1,290 @@
|
||||
{
|
||||
"cells": [
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 20,
|
||||
"id": "4c475754",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from langchain.prompts import PromptTemplate\n",
|
||||
"from langchain.prompts.base import BaseOutputParser\n",
|
||||
"from langchain import OpenAI, LLMChain\n",
|
||||
"from langchain.chains.llm_for_loop import LLMForLoopChain"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 21,
|
||||
"id": "efcdd239",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# First we make a chain that generates the list"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 13,
|
||||
"id": "2b1884f5",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from typing import Optional\n",
|
||||
"import re\n",
|
||||
"class ListOutputParser(BaseOutputParser):\n",
|
||||
" \n",
|
||||
" def __init__(self, regex: Optional[str] = None):\n",
|
||||
"        self.regex = regex\n",
|
||||
" \n",
|
||||
" def parse(self, text: str) -> list:\n",
|
||||
" splits = [t for t in text.split(\"\\n\") if t]\n",
|
||||
" if self.regex is not None:\n",
|
||||
" splits = [re.match(self.regex, s).group(1) for s in splits]\n",
|
||||
" return splits"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 30,
|
||||
"id": "b2b7f8fa",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"template = \"\"\"You are a list maker. Your job is make lists given a certain user input.\n",
|
||||
"\n",
|
||||
"The format of your lists should be:\n",
|
||||
"\n",
|
||||
"```\n",
|
||||
"List:\n",
|
||||
"- Item 1\n",
|
||||
"- Item 2\n",
|
||||
"...\n",
|
||||
"```\n",
|
||||
"\n",
|
||||
"Begin!:\n",
|
||||
"\n",
|
||||
"User input: {input}\n",
|
||||
"List:\"\"\"\n",
|
||||
"output_parser = ListOutputParser(regex=\"- (.*)\")\n",
|
||||
"prompt = PromptTemplate(template=template, input_variables=[\"input\"], output_parser=output_parser)\n",
|
||||
"\n",
|
||||
"chain = LLMChain(llm=OpenAI(temperature=0), prompt=prompt)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 44,
|
||||
"id": "2f8ea6ba",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"data": {
|
||||
"text/plain": [
|
||||
"['Tesla', 'Nissan', 'BMW', 'BYD', 'Volkswagen']"
|
||||
]
|
||||
},
|
||||
"execution_count": 44,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"chain.predict_and_parse(input=\"top 5 ev companies\")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 32,
|
||||
"id": "1fdfc7cb",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# Next we generate the chain that we run over each item"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 38,
|
||||
"id": "0b8f115a",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"template = \"\"\"For the following company, explain the origin of their name:\n",
|
||||
"\n",
|
||||
"Company: {company}\n",
|
||||
"Explanation of their name:\"\"\"\n",
|
||||
"prompt = PromptTemplate(template=template, input_variables=[\"company\"])\n",
|
||||
"\n",
|
||||
"explanation_chain = LLMChain(llm=OpenAI(), prompt=prompt, verbose=True)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 39,
|
||||
"id": "6d636881",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"\n",
|
||||
"\n",
|
||||
"\u001b[1m> Entering new chain...\u001b[0m\n",
|
||||
"Prompt after formatting:\n",
|
||||
"\u001b[32;1m\u001b[1;3mFor the following company, explain the origin of their name:\n",
|
||||
"\n",
|
||||
"Company: Tesla\n",
|
||||
"Explanation of their name:\u001b[0m\n",
|
||||
"\n",
|
||||
"\u001b[1m> Finished chain.\u001b[0m\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"data": {
|
||||
"text/plain": [
|
||||
"'\\n\\nTesla is a company that specializes in electric cars and renewable energy. The company is named after Nikola Tesla, a Serbian-American inventor and electrical engineer who was born in the 19th century.'"
|
||||
]
|
||||
},
|
||||
"execution_count": 39,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"explanation_chain.predict(company=\"Tesla\")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 40,
|
||||
"id": "3c236dd3",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# Now we combine them"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 41,
|
||||
"id": "941b6389",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"for_loop_chain = LLMForLoopChain(llm_chain=chain, apply_chain=explanation_chain)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 45,
|
||||
"id": "98c39dbc",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"\n",
|
||||
"\n",
|
||||
"\u001b[1m> Entering new chain...\u001b[0m\n",
|
||||
"Prompt after formatting:\n",
|
||||
"\u001b[32;1m\u001b[1;3mFor the following company, explain the origin of their name:\n",
|
||||
"\n",
|
||||
"Company: Tesla\n",
|
||||
"Explanation of their name:\u001b[0m\n",
|
||||
"\n",
|
||||
"\u001b[1m> Finished chain.\u001b[0m\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"\u001b[1m> Entering new chain...\u001b[0m\n",
|
||||
"Prompt after formatting:\n",
|
||||
"\u001b[32;1m\u001b[1;3mFor the following company, explain the origin of their name:\n",
|
||||
"\n",
|
||||
"Company: Nissan\n",
|
||||
"Explanation of their name:\u001b[0m\n",
|
||||
"\n",
|
||||
"\u001b[1m> Finished chain.\u001b[0m\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"\u001b[1m> Entering new chain...\u001b[0m\n",
|
||||
"Prompt after formatting:\n",
|
||||
"\u001b[32;1m\u001b[1;3mFor the following company, explain the origin of their name:\n",
|
||||
"\n",
|
||||
"Company: BMW\n",
|
||||
"Explanation of their name:\u001b[0m\n",
|
||||
"\n",
|
||||
"\u001b[1m> Finished chain.\u001b[0m\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"\u001b[1m> Entering new chain...\u001b[0m\n",
|
||||
"Prompt after formatting:\n",
|
||||
"\u001b[32;1m\u001b[1;3mFor the following company, explain the origin of their name:\n",
|
||||
"\n",
|
||||
"Company: BYD\n",
|
||||
"Explanation of their name:\u001b[0m\n",
|
||||
"\n",
|
||||
"\u001b[1m> Finished chain.\u001b[0m\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"\u001b[1m> Entering new chain...\u001b[0m\n",
|
||||
"Prompt after formatting:\n",
|
||||
"\u001b[32;1m\u001b[1;3mFor the following company, explain the origin of their name:\n",
|
||||
"\n",
|
||||
"Company: Volkswagen\n",
|
||||
"Explanation of their name:\u001b[0m\n",
|
||||
"\n",
|
||||
"\u001b[1m> Finished chain.\u001b[0m\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"data": {
|
||||
"text/plain": [
|
||||
"['\\n\\nTesla was named after the Serbian-American inventor Nikola Tesla, who was known for his work in electricity and magnetism.',\n",
|
||||
" '\\n\\nNissan is a Japanese company, and their name comes from the Japanese word for \"sun.\"',\n",
|
||||
" \"\\n\\nThe company's name is an abbreviation for Bayerische Motoren Werke, which is German for Bavarian Motor Works.\",\n",
|
||||
" '\\n\\nThe company\\'s name is derived from the Chinese characters \"Baiyu Dong\", which literally mean \"to catch the rain in the east\". The name is a reference to the company\\'s origins in the city of Shenzhen, in southeastern China.',\n",
|
||||
" '\\n\\nVolkswagen is a German car company. The word \"Volkswagen\" means \"people\\'s car\" in German.']"
|
||||
]
|
||||
},
|
||||
"execution_count": 45,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"for_loop_chain.run_list(input=\"top 5 ev companies\")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "5a2c1803",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": []
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
"kernelspec": {
|
||||
"display_name": "Python 3 (ipykernel)",
|
||||
"language": "python",
|
||||
"name": "python3"
|
||||
},
|
||||
"language_info": {
|
||||
"codemirror_mode": {
|
||||
"name": "ipython",
|
||||
"version": 3
|
||||
},
|
||||
"file_extension": ".py",
|
||||
"mimetype": "text/x-python",
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.7.6"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 5
|
||||
}
|
@ -1,5 +1,5 @@
|
||||
"""Chain that just formats a prompt and calls an LLM."""
|
||||
from typing import Any, Dict, List
|
||||
from typing import Any, Dict, List, Union
|
||||
|
||||
from pydantic import BaseModel, Extra
|
||||
|
||||
@ -78,3 +78,11 @@ class LLMChain(Chain, BaseModel):
|
||||
completion = llm.predict(adjective="funny")
|
||||
"""
|
||||
return self(kwargs)[self.output_key]
|
||||
|
||||
def predict_and_parse(self, **kwargs: Any) -> Union[str, List[str], Dict[str, str]]:
    """Run ``predict`` and post-process the completion.

    If the chain's prompt carries an output parser, the raw LLM text is fed
    through it; otherwise the text is returned untouched.
    """
    text = self.predict(**kwargs)
    parser = self.prompt.output_parser
    if parser is None:
        return text
    return parser.parse(text)
|
||||
|
51
langchain/chains/llm_for_loop.py
Normal file
51
langchain/chains/llm_for_loop.py
Normal file
@ -0,0 +1,51 @@
|
||||
|
||||
from typing import Dict, List, Any
|
||||
|
||||
from pydantic import BaseModel, Extra
|
||||
|
||||
from langchain.chains.base import Chain
|
||||
from langchain.chains.llm import LLMChain
|
||||
|
||||
|
||||
class LLMForLoopChain(Chain, BaseModel):
    """Chain that first uses an LLM to generate multiple items then loops over them.

    The ``llm_chain`` is expected to produce a parseable list (via its prompt's
    output parser); ``apply_chain`` is then run once per generated item.
    """

    llm_chain: LLMChain
    """LLM chain to use to generate multiple items."""
    apply_chain: Chain
    """Chain to apply to each item that is generated."""
    output_key: str = "text"  #: :meta private:

    class Config:
        """Configuration for this pydantic object."""

        extra = Extra.forbid
        arbitrary_types_allowed = True

    @property
    def input_keys(self) -> List[str]:
        """Will be whatever keys the list-generating prompt expects.

        :meta private:
        """
        return self.llm_chain.prompt.input_variables

    @property
    def output_keys(self) -> List[str]:
        """Will always return text key.

        :meta private:
        """
        return [self.output_key]

    def run_list(self, **kwargs: Any) -> List[str]:
        """Get list from LLM chain and then run chain on each item.

        Raises whatever ``predict_and_parse`` raises if the prompt has no
        output parser yielding an iterable of items.
        """
        output_items = self.llm_chain.predict_and_parse(**kwargs)
        # Comprehension instead of a manual append loop (same order, same calls).
        return [self.apply_chain.run(item) for item in output_items]

    def _call(self, inputs: Dict[str, str]) -> Dict[str, str]:
        # Join the per-item outputs into one text blob to satisfy the
        # single-output-key Chain contract.
        res = self.run_list(**inputs)
        return {self.output_key: "\n\n".join(res)}
|
@ -1,8 +1,8 @@
|
||||
"""BasePrompt schema definition."""
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Any, Dict, List, Union
|
||||
from typing import Any, Dict, List, Union, Optional
|
||||
|
||||
from pydantic import BaseModel, Field, root_validator
|
||||
from pydantic import BaseModel, Field, root_validator, Extra
|
||||
|
||||
from langchain.formatting import formatter
|
||||
|
||||
@ -29,7 +29,7 @@ def check_valid_template(
|
||||
raise ValueError("Invalid prompt schema.")
|
||||
|
||||
|
||||
class OutputParser(ABC):
|
||||
class BaseOutputParser(ABC):
|
||||
"""Class to parse the output of an LLM call."""
|
||||
|
||||
@abstractmethod
|
||||
@ -37,22 +37,21 @@ class OutputParser(ABC):
|
||||
"""Parse the output of an LLM call."""
|
||||
|
||||
|
||||
class DefaultParser(OutputParser):
|
||||
"""Just return the text."""
|
||||
|
||||
def parse(self, text: str) -> Union[str, List[str], Dict[str, str]]:
|
||||
"""Parse the output of an LLM call."""
|
||||
return text
|
||||
|
||||
|
||||
class BasePromptTemplate(BaseModel, ABC):
|
||||
"""Base prompt should expose the format method, returning a prompt."""
|
||||
|
||||
input_variables: List[str]
|
||||
"""A list of the names of the variables the prompt template expects."""
|
||||
output_parser: OutputParser = Field(default_factory=DefaultParser)
|
||||
output_parser: Optional[BaseOutputParser] = None
|
||||
"""How to parse the output of calling an LLM on this formatted prompt."""
|
||||
|
||||
class Config:
|
||||
"""Configuration for this pydantic object."""
|
||||
|
||||
extra = Extra.forbid
|
||||
arbitrary_types_allowed = True
|
||||
|
||||
@root_validator()
|
||||
def validate_variable_names(cls, values: Dict) -> Dict:
|
||||
"""Validate variable names do not restricted names."""
|
||||
|
Loading…
Reference in New Issue
Block a user