From 3ef44f41b768039a339a937cb0bbdaf54956e42c Mon Sep 17 00:00:00 2001 From: Harrison Chase Date: Sat, 26 Nov 2022 13:44:38 -0800 Subject: [PATCH] add llm for loop --- docs/examples/chains/llm_for_loop.ipynb | 290 ++++++++++++++++++++++++ langchain/chains/llm.py | 10 +- langchain/chains/llm_for_loop.py | 51 +++++ langchain/prompts/base.py | 21 +- 4 files changed, 360 insertions(+), 12 deletions(-) create mode 100644 docs/examples/chains/llm_for_loop.ipynb create mode 100644 langchain/chains/llm_for_loop.py diff --git a/docs/examples/chains/llm_for_loop.ipynb b/docs/examples/chains/llm_for_loop.ipynb new file mode 100644 index 00000000000..c04fb83909c --- /dev/null +++ b/docs/examples/chains/llm_for_loop.ipynb @@ -0,0 +1,290 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 20, + "id": "4c475754", + "metadata": {}, + "outputs": [], + "source": [ + "from langchain.prompts import PromptTemplate\n", + "from langchain.prompts.base import BaseOutputParser\n", + "from langchain import OpenAI, LLMChain\n", + "from langchain.chains.llm_for_loop import LLMForLoopChain" + ] + }, + { + "cell_type": "code", + "execution_count": 21, + "id": "efcdd239", + "metadata": {}, + "outputs": [], + "source": [ + "# First we make a chain that generates the list" + ] + }, + { + "cell_type": "code", + "execution_count": 13, + "id": "2b1884f5", + "metadata": {}, + "outputs": [], + "source": [ + "from typing import Optional\n", + "import re\n", + "class ListOutputParser(BaseOutputParser):\n", + " \n", + " def __init__(self, regex: Optional[str] = None):\n", + " self.regex=regex\n", + " \n", + " def parse(self, text: str) -> list:\n", + " splits = [t for t in text.split(\"\\n\") if t]\n", + " if self.regex is not None:\n", + " splits = [re.match(self.regex, s).group(1) for s in splits]\n", + " return splits" + ] + }, + { + "cell_type": "code", + "execution_count": 30, + "id": "b2b7f8fa", + "metadata": {}, + "outputs": [], + "source": [ + "template = \"\"\"You are a list maker. Your job is make lists given a certain user input.\n", + "\n", + "The format of your lists should be:\n", + "\n", + "```\n", + "List:\n", + "- Item 1\n", + "- Item 2\n", + "...\n", + "```\n", + "\n", + "Begin!:\n", + "\n", + "User input: {input}\n", + "List:\"\"\"\n", + "output_parser = ListOutputParser(regex=\"- (.*)\")\n", + "prompt = PromptTemplate(template=template, input_variables=[\"input\"], output_parser=output_parser)\n", + "\n", + "chain = LLMChain(llm=OpenAI(temperature=0), prompt=prompt)" + ] + }, + { + "cell_type": "code", + "execution_count": 44, + "id": "2f8ea6ba", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "['Tesla', 'Nissan', 'BMW', 'BYD', 'Volkswagen']" + ] + }, + "execution_count": 44, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "chain.predict_and_parse(input=\"top 5 ev companies\")" + ] + }, + { + "cell_type": "code", + "execution_count": 32, + "id": "1fdfc7cb", + "metadata": {}, + "outputs": [], + "source": [ + "# Next we generate the chain that we run over each item" + ] + }, + { + "cell_type": "code", + "execution_count": 38, + "id": "0b8f115a", + "metadata": {}, + "outputs": [], + "source": [ + "template = \"\"\"For the following company, explain the origin of their name:\n", + "\n", + "Company: {company}\n", + "Explanation of their name:\"\"\"\n", + "prompt = PromptTemplate(template=template, input_variables=[\"company\"])\n", + "\n", + "explanation_chain = LLMChain(llm=OpenAI(), prompt=prompt, verbose=True)" + ] + }, + { + "cell_type": "code", + "execution_count": 39, + "id": "6d636881", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\n", + "\n", + "\u001b[1m> Entering new chain...\u001b[0m\n", + "Prompt after formatting:\n", + "\u001b[32;1m\u001b[1;3mFor the following company, explain the origin of their name:\n", + "\n", + "Company: Tesla\n", + "Explanation of their name:\u001b[0m\n", + "\n", + "\u001b[1m> Finished chain.\u001b[0m\n" + ] + }, + { + "data": { + "text/plain": [ + "'\\n\\nTesla is a company that specializes in electric cars and renewable energy. The company is named after Nikola Tesla, a Serbian-American inventor and electrical engineer who was born in the 19th century.'" + ] + }, + "execution_count": 39, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "explanation_chain.predict(company=\"Tesla\")" + ] + }, + { + "cell_type": "code", + "execution_count": 40, + "id": "3c236dd3", + "metadata": {}, + "outputs": [], + "source": [ + "# Now we combine them" + ] + }, + { + "cell_type": "code", + "execution_count": 41, + "id": "941b6389", + "metadata": {}, + "outputs": [], + "source": [ + "for_loop_chain = LLMForLoopChain(llm_chain=chain, apply_chain=explanation_chain)" + ] + }, + { + "cell_type": "code", + "execution_count": 45, + "id": "98c39dbc", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\n", + "\n", + "\u001b[1m> Entering new chain...\u001b[0m\n", + "Prompt after formatting:\n", + "\u001b[32;1m\u001b[1;3mFor the following company, explain the origin of their name:\n", + "\n", + "Company: Tesla\n", + "Explanation of their name:\u001b[0m\n", + "\n", + "\u001b[1m> Finished chain.\u001b[0m\n", + "\n", + "\n", + "\u001b[1m> Entering new chain...\u001b[0m\n", + "Prompt after formatting:\n", + "\u001b[32;1m\u001b[1;3mFor the following company, explain the origin of their name:\n", + "\n", + "Company: Nissan\n", + "Explanation of their name:\u001b[0m\n", + "\n", + "\u001b[1m> Finished chain.\u001b[0m\n", + "\n", + "\n", + "\u001b[1m> Entering new chain...\u001b[0m\n", + "Prompt after formatting:\n", + "\u001b[32;1m\u001b[1;3mFor the following company, explain the origin of their name:\n", + "\n", + "Company: BMW\n", + "Explanation of their name:\u001b[0m\n", + "\n", + "\u001b[1m> Finished chain.\u001b[0m\n", + "\n", + "\n", + "\u001b[1m> Entering new chain...\u001b[0m\n", + "Prompt after formatting:\n", + "\u001b[32;1m\u001b[1;3mFor the following company, explain the origin of their name:\n", + "\n", + "Company: BYD\n", + "Explanation of their name:\u001b[0m\n", + "\n", + "\u001b[1m> Finished chain.\u001b[0m\n", + "\n", + "\n", + "\u001b[1m> Entering new chain...\u001b[0m\n", + "Prompt after formatting:\n", + "\u001b[32;1m\u001b[1;3mFor the following company, explain the origin of their name:\n", + "\n", + "Company: Volkswagen\n", + "Explanation of their name:\u001b[0m\n", + "\n", + "\u001b[1m> Finished chain.\u001b[0m\n" + ] + }, + { + "data": { + "text/plain": [ + "['\\n\\nTesla was named after the Serbian-American inventor Nikola Tesla, who was known for his work in electricity and magnetism.',\n", + " '\\n\\nNissan is a Japanese company, and their name comes from the Japanese word for \"sun.\"',\n", + " \"\\n\\nThe company's name is an abbreviation for Bayerische Motoren Werke, which is German for Bavarian Motor Works.\",\n", + " '\\n\\nThe company\\'s name is derived from the Chinese characters \"Baiyu Dong\", which literally mean \"to catch the rain in the east\". The name is a reference to the company\\'s origins in the city of Shenzhen, in southeastern China.',\n", + " '\\n\\nVolkswagen is a German car company. The word \"Volkswagen\" means \"people\\'s car\" in German.']" + ] + }, + "execution_count": 45, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "for_loop_chain.run_list(input=\"top 5 ev companies\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "5a2c1803", + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.7.6" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/langchain/chains/llm.py b/langchain/chains/llm.py index 58514bf1fd5..544628524c8 100644 --- a/langchain/chains/llm.py +++ b/langchain/chains/llm.py @@ -1,5 +1,5 @@ """Chain that just formats a prompt and calls an LLM.""" -from typing import Any, Dict, List +from typing import Any, Dict, List, Union from pydantic import BaseModel, Extra @@ -78,3 +78,11 @@ class LLMChain(Chain, BaseModel): completion = llm.predict(adjective="funny") """ return self(kwargs)[self.output_key] + + def predict_and_parse(self, **kwargs: Any) -> Union[str, List[str], Dict[str, str]]: + """Call predict and then parse the results.""" + result = self.predict(**kwargs) + if self.prompt.output_parser is not None: + return self.prompt.output_parser.parse(result) + else: + return result diff --git a/langchain/chains/llm_for_loop.py b/langchain/chains/llm_for_loop.py new file mode 100644 index 00000000000..7c2b63d060a --- /dev/null +++ b/langchain/chains/llm_for_loop.py @@ -0,0 +1,51 @@ + +from typing import Dict, List, Any + +from pydantic import BaseModel, Extra + +from langchain.chains.base import Chain +from langchain.chains.llm import LLMChain + + +class LLMForLoopChain(Chain, BaseModel): + """Chain that first uses an LLM to generate multiple items then loops over them.""" + + llm_chain: LLMChain + """LLM chain to use to generate multiple items.""" + apply_chain: Chain + """Chain to apply to each item that is generated.""" + output_key: str = "text" #: :meta private: + + class Config: + """Configuration for this pydantic object.""" + + extra = Extra.forbid + arbitrary_types_allowed = True + + @property + def input_keys(self) -> List[str]: + """Will be whatever keys the prompt expects. + + :meta private: + """ + return self.llm_chain.prompt.input_variables + + @property + def output_keys(self) -> List[str]: + """Will always return text key. + + :meta private: + """ + return [self.output_key] + + def run_list(self, **kwargs: Any) -> List[str]: + """Get list from LLM chain and then run chain on each item.""" + output_items = self.llm_chain.predict_and_parse(**kwargs) + res = [] + for item in output_items: + res.append(self.apply_chain.run(item)) + return res + + def _call(self, inputs: Dict[str, str]) -> Dict[str, str]: + res = self.run_list(**inputs) + return {self.output_key: "\n\n".join(res)} diff --git a/langchain/prompts/base.py b/langchain/prompts/base.py index 8d133cb65bf..849ea926e62 100644 --- a/langchain/prompts/base.py +++ b/langchain/prompts/base.py @@ -1,8 +1,8 @@ """BasePrompt schema definition.""" from abc import ABC, abstractmethod -from typing import Any, Dict, List, Union +from typing import Any, Dict, List, Union, Optional -from pydantic import BaseModel, Field, root_validator +from pydantic import BaseModel, Field, root_validator, Extra from langchain.formatting import formatter @@ -29,7 +29,7 @@ def check_valid_template( raise ValueError("Invalid prompt schema.") -class OutputParser(ABC): +class BaseOutputParser(ABC): """Class to parse the output of an LLM call.""" @abstractmethod @@ -37,22 +37,21 @@ class OutputParser(ABC): """Parse the output of an LLM call.""" -class DefaultParser(OutputParser): - """Just return the text.""" - - def parse(self, text: str) -> Union[str, List[str], Dict[str, str]]: - """Parse the output of an LLM call.""" - return text - class BasePromptTemplate(BaseModel, ABC): """Base prompt should expose the format method, returning a prompt.""" input_variables: List[str] """A list of the names of the variables the prompt template expects.""" - output_parser: OutputParser = Field(default_factory=DefaultParser) + output_parser: Optional[BaseOutputParser] = None """How to parse the output of calling an LLM on this formatted prompt.""" + class Config: + """Configuration for this pydantic object.""" + + extra = Extra.forbid + arbitrary_types_allowed = True + @root_validator() def validate_variable_names(cls, values: Dict) -> Dict: """Validate variable names do not restricted names."""