add llm for loop

This commit is contained in:
Harrison Chase 2022-11-26 13:44:38 -08:00
parent a57e74996f
commit 3ef44f41b7
4 changed files with 360 additions and 12 deletions

View File

@ -0,0 +1,290 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": 20,
"id": "4c475754",
"metadata": {},
"outputs": [],
"source": [
"from langchain.prompts import PromptTemplate\n",
"from langchain.prompts.base import BaseOutputParser\n",
"from langchain import OpenAI, LLMChain\n",
"from langchain.chains.llm_for_loop import LLMForLoopChain"
]
},
{
"cell_type": "code",
"execution_count": 21,
"id": "efcdd239",
"metadata": {},
"outputs": [],
"source": [
"# First we make a chain that generates the list"
]
},
{
"cell_type": "code",
"execution_count": 13,
"id": "2b1884f5",
"metadata": {},
"outputs": [],
"source": [
"from typing import Optional\n",
"import re\n",
"class ListOutputParser(BaseOutputParser):\n",
" \n",
" def __init__(self, regex: Optional[str] = None):\n",
" self.regex=regex\n",
" \n",
" def parse(self, text: str) -> list:\n",
" splits = [t for t in text.split(\"\\n\") if t]\n",
" if self.regex is not None:\n",
" splits = [re.match(self.regex, s).group(1) for s in splits]\n",
" return splits"
]
},
{
"cell_type": "code",
"execution_count": 30,
"id": "b2b7f8fa",
"metadata": {},
"outputs": [],
"source": [
"template = \"\"\"You are a list maker. Your job is make lists given a certain user input.\n",
"\n",
"The format of your lists should be:\n",
"\n",
"```\n",
"List:\n",
"- Item 1\n",
"- Item 2\n",
"...\n",
"```\n",
"\n",
"Begin!:\n",
"\n",
"User input: {input}\n",
"List:\"\"\"\n",
"output_parser = ListOutputParser(regex=\"- (.*)\")\n",
"prompt = PromptTemplate(template=template, input_variables=[\"input\"], output_parser=output_parser)\n",
"\n",
"chain = LLMChain(llm=OpenAI(temperature=0), prompt=prompt)"
]
},
{
"cell_type": "code",
"execution_count": 44,
"id": "2f8ea6ba",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"['Tesla', 'Nissan', 'BMW', 'BYD', 'Volkswagen']"
]
},
"execution_count": 44,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"chain.predict_and_parse(input=\"top 5 ev companies\")"
]
},
{
"cell_type": "code",
"execution_count": 32,
"id": "1fdfc7cb",
"metadata": {},
"outputs": [],
"source": [
"# Next we generate the chain that we run over each item"
]
},
{
"cell_type": "code",
"execution_count": 38,
"id": "0b8f115a",
"metadata": {},
"outputs": [],
"source": [
"template = \"\"\"For the following company, explain the origin of their name:\n",
"\n",
"Company: {company}\n",
"Explanation of their name:\"\"\"\n",
"prompt = PromptTemplate(template=template, input_variables=[\"company\"])\n",
"\n",
"explanation_chain = LLMChain(llm=OpenAI(), prompt=prompt, verbose=True)"
]
},
{
"cell_type": "code",
"execution_count": 39,
"id": "6d636881",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n",
"\n",
"\u001b[1m> Entering new chain...\u001b[0m\n",
"Prompt after formatting:\n",
"\u001b[32;1m\u001b[1;3mFor the following company, explain the origin of their name:\n",
"\n",
"Company: Tesla\n",
"Explanation of their name:\u001b[0m\n",
"\n",
"\u001b[1m> Finished chain.\u001b[0m\n"
]
},
{
"data": {
"text/plain": [
"'\\n\\nTesla is a company that specializes in electric cars and renewable energy. The company is named after Nikola Tesla, a Serbian-American inventor and electrical engineer who was born in the 19th century.'"
]
},
"execution_count": 39,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"explanation_chain.predict(company=\"Tesla\")"
]
},
{
"cell_type": "code",
"execution_count": 40,
"id": "3c236dd3",
"metadata": {},
"outputs": [],
"source": [
"# Now we combine them"
]
},
{
"cell_type": "code",
"execution_count": 41,
"id": "941b6389",
"metadata": {},
"outputs": [],
"source": [
"for_loop_chain = LLMForLoopChain(llm_chain=chain, apply_chain=explanation_chain)"
]
},
{
"cell_type": "code",
"execution_count": 45,
"id": "98c39dbc",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n",
"\n",
"\u001b[1m> Entering new chain...\u001b[0m\n",
"Prompt after formatting:\n",
"\u001b[32;1m\u001b[1;3mFor the following company, explain the origin of their name:\n",
"\n",
"Company: Tesla\n",
"Explanation of their name:\u001b[0m\n",
"\n",
"\u001b[1m> Finished chain.\u001b[0m\n",
"\n",
"\n",
"\u001b[1m> Entering new chain...\u001b[0m\n",
"Prompt after formatting:\n",
"\u001b[32;1m\u001b[1;3mFor the following company, explain the origin of their name:\n",
"\n",
"Company: Nissan\n",
"Explanation of their name:\u001b[0m\n",
"\n",
"\u001b[1m> Finished chain.\u001b[0m\n",
"\n",
"\n",
"\u001b[1m> Entering new chain...\u001b[0m\n",
"Prompt after formatting:\n",
"\u001b[32;1m\u001b[1;3mFor the following company, explain the origin of their name:\n",
"\n",
"Company: BMW\n",
"Explanation of their name:\u001b[0m\n",
"\n",
"\u001b[1m> Finished chain.\u001b[0m\n",
"\n",
"\n",
"\u001b[1m> Entering new chain...\u001b[0m\n",
"Prompt after formatting:\n",
"\u001b[32;1m\u001b[1;3mFor the following company, explain the origin of their name:\n",
"\n",
"Company: BYD\n",
"Explanation of their name:\u001b[0m\n",
"\n",
"\u001b[1m> Finished chain.\u001b[0m\n",
"\n",
"\n",
"\u001b[1m> Entering new chain...\u001b[0m\n",
"Prompt after formatting:\n",
"\u001b[32;1m\u001b[1;3mFor the following company, explain the origin of their name:\n",
"\n",
"Company: Volkswagen\n",
"Explanation of their name:\u001b[0m\n",
"\n",
"\u001b[1m> Finished chain.\u001b[0m\n"
]
},
{
"data": {
"text/plain": [
"['\\n\\nTesla was named after the Serbian-American inventor Nikola Tesla, who was known for his work in electricity and magnetism.',\n",
" '\\n\\nNissan is a Japanese company, and their name comes from the Japanese word for \"sun.\"',\n",
" \"\\n\\nThe company's name is an abbreviation for Bayerische Motoren Werke, which is German for Bavarian Motor Works.\",\n",
" '\\n\\nThe company\\'s name is derived from the Chinese characters \"Baiyu Dong\", which literally mean \"to catch the rain in the east\". The name is a reference to the company\\'s origins in the city of Shenzhen, in southeastern China.',\n",
" '\\n\\nVolkswagen is a German car company. The word \"Volkswagen\" means \"people\\'s car\" in German.']"
]
},
"execution_count": 45,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"for_loop_chain.run_list(input=\"top 5 ev companies\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "5a2c1803",
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.7.6"
}
},
"nbformat": 4,
"nbformat_minor": 5
}

View File

@ -1,5 +1,5 @@
"""Chain that just formats a prompt and calls an LLM.""" """Chain that just formats a prompt and calls an LLM."""
from typing import Any, Dict, List from typing import Any, Dict, List, Union
from pydantic import BaseModel, Extra from pydantic import BaseModel, Extra
@ -78,3 +78,11 @@ class LLMChain(Chain, BaseModel):
completion = llm.predict(adjective="funny") completion = llm.predict(adjective="funny")
""" """
return self(kwargs)[self.output_key] return self(kwargs)[self.output_key]
def predict_and_parse(self, **kwargs: Any) -> Union[str, List[str], Dict[str, str]]:
    """Run predict on the formatted prompt, then parse the raw completion.

    If the prompt carries an output parser, the raw text is passed through
    it; otherwise the raw completion string is returned unchanged.
    """
    raw_output = self.predict(**kwargs)
    parser = self.prompt.output_parser
    if parser is None:
        return raw_output
    return parser.parse(raw_output)

View File

@ -0,0 +1,51 @@
from typing import Dict, List, Any
from pydantic import BaseModel, Extra
from langchain.chains.base import Chain
from langchain.chains.llm import LLMChain
class LLMForLoopChain(Chain, BaseModel):
    """Chain that first uses an LLM to generate multiple items then loops over them."""

    llm_chain: LLMChain
    """LLM chain to use to generate multiple items."""
    apply_chain: Chain
    """Chain to apply to each item that is generated."""
    output_key: str = "text"  #: :meta private:

    class Config:
        """Configuration for this pydantic object."""

        extra = Extra.forbid
        arbitrary_types_allowed = True

    @property
    def input_keys(self) -> List[str]:
        """Will be whatever keys the prompt expects.

        :meta private:
        """
        # The list-generating chain's prompt dictates the required inputs.
        return self.llm_chain.prompt.input_variables

    @property
    def output_keys(self) -> List[str]:
        """Will always return text key.

        :meta private:
        """
        return [self.output_key]

    def run_list(self, **kwargs: Any) -> List[str]:
        """Get list from LLM chain and then run chain on each item."""
        generated_items = self.llm_chain.predict_and_parse(**kwargs)
        # Apply the secondary chain to every generated item, in order.
        return [self.apply_chain.run(one_item) for one_item in generated_items]

    def _call(self, inputs: Dict[str, str]) -> Dict[str, str]:
        joined = "\n\n".join(self.run_list(**inputs))
        return {self.output_key: joined}

View File

@ -1,8 +1,8 @@
"""BasePrompt schema definition.""" """BasePrompt schema definition."""
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from typing import Any, Dict, List, Union from typing import Any, Dict, List, Union, Optional
from pydantic import BaseModel, Field, root_validator from pydantic import BaseModel, Field, root_validator, Extra
from langchain.formatting import formatter from langchain.formatting import formatter
@ -29,7 +29,7 @@ def check_valid_template(
raise ValueError("Invalid prompt schema.") raise ValueError("Invalid prompt schema.")
class OutputParser(ABC): class BaseOutputParser(ABC):
"""Class to parse the output of an LLM call.""" """Class to parse the output of an LLM call."""
@abstractmethod @abstractmethod
@ -37,22 +37,21 @@ class OutputParser(ABC):
"""Parse the output of an LLM call.""" """Parse the output of an LLM call."""
class DefaultParser(OutputParser):
"""Just return the text."""
def parse(self, text: str) -> Union[str, List[str], Dict[str, str]]:
"""Parse the output of an LLM call."""
return text
class BasePromptTemplate(BaseModel, ABC): class BasePromptTemplate(BaseModel, ABC):
"""Base prompt should expose the format method, returning a prompt.""" """Base prompt should expose the format method, returning a prompt."""
input_variables: List[str] input_variables: List[str]
"""A list of the names of the variables the prompt template expects.""" """A list of the names of the variables the prompt template expects."""
output_parser: OutputParser = Field(default_factory=DefaultParser) output_parser: Optional[BaseOutputParser] = None
"""How to parse the output of calling an LLM on this formatted prompt.""" """How to parse the output of calling an LLM on this formatted prompt."""
class Config:
"""Configuration for this pydantic object."""
extra = Extra.forbid
arbitrary_types_allowed = True
@root_validator() @root_validator()
def validate_variable_names(cls, values: Dict) -> Dict: def validate_variable_names(cls, values: Dict) -> Dict:
"""Validate variable names do not restricted names.""" """Validate variable names do not restricted names."""