Compare commits

...

9 Commits

Author SHA1 Message Date
Harrison Chase
cc606180cd Merge branch 'master' into harrison/use_output_parser 2022-12-03 13:13:34 -08:00
Harrison Chase
f423bbc8ac Merge branch 'master' into harrison/use_output_parser 2022-12-03 13:13:01 -08:00
Harrison Chase
bfe50949f5 cr 2022-12-01 16:28:36 -08:00
Harrison Chase
9966fd0e05 cr 2022-12-01 16:27:54 -08:00
Harrison Chase
3ef44f41b7 add llm for loop 2022-11-26 13:44:38 -08:00
Harrison Chase
a57e74996f add output parser 2022-11-26 07:15:54 -08:00
Harrison Chase
67685b874e stash 2022-11-25 13:13:27 -08:00
Harrison Chase
9b674d3dc6 Merge branch 'master' into harrison/output_parser 2022-11-25 10:47:01 -08:00
Harrison Chase
c09fe1dfdf stash 2022-11-24 09:46:29 -08:00
4 changed files with 388 additions and 3 deletions

View File

@@ -0,0 +1,290 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": 20,
"id": "4c475754",
"metadata": {},
"outputs": [],
"source": [
"from langchain.prompts import PromptTemplate\n",
"from langchain.prompts.base import BaseOutputParser\n",
"from langchain import OpenAI, LLMChain\n",
"from langchain.chains.llm_for_loop import LLMForLoopChain"
]
},
{
"cell_type": "code",
"execution_count": 21,
"id": "efcdd239",
"metadata": {},
"outputs": [],
"source": [
"# First we make a chain that generates the list"
]
},
{
"cell_type": "code",
"execution_count": 13,
"id": "2b1884f5",
"metadata": {},
"outputs": [],
"source": [
"from typing import Optional\n",
"import re\n",
"class ListOutputParser(BaseOutputParser):\n",
" \n",
" def __init__(self, regex: Optional[str] = None):\n",
" self.regex=regex\n",
" \n",
" def parse(self, text: str) -> list:\n",
" splits = [t for t in text.split(\"\\n\") if t]\n",
" if self.regex is not None:\n",
" splits = [re.match(self.regex, s).group(1) for s in splits]\n",
" return splits"
]
},
{
"cell_type": "code",
"execution_count": 30,
"id": "b2b7f8fa",
"metadata": {},
"outputs": [],
"source": [
"template = \"\"\"You are a list maker. Your job is make lists given a certain user input.\n",
"\n",
"The format of your lists should be:\n",
"\n",
"```\n",
"List:\n",
"- Item 1\n",
"- Item 2\n",
"...\n",
"```\n",
"\n",
"Begin!:\n",
"\n",
"User input: {input}\n",
"List:\"\"\"\n",
"output_parser = ListOutputParser(regex=\"- (.*)\")\n",
"prompt = PromptTemplate(template=template, input_variables=[\"input\"], output_parser=output_parser)\n",
"\n",
"chain = LLMChain(llm=OpenAI(temperature=0), prompt=prompt)"
]
},
{
"cell_type": "code",
"execution_count": 44,
"id": "2f8ea6ba",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"['Tesla', 'Nissan', 'BMW', 'BYD', 'Volkswagen']"
]
},
"execution_count": 44,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"chain.predict_and_parse(input=\"top 5 ev companies\")"
]
},
{
"cell_type": "code",
"execution_count": 32,
"id": "1fdfc7cb",
"metadata": {},
"outputs": [],
"source": [
"# Next we generate the chain that we run over each item"
]
},
{
"cell_type": "code",
"execution_count": 38,
"id": "0b8f115a",
"metadata": {},
"outputs": [],
"source": [
"template = \"\"\"For the following company, explain the origin of their name:\n",
"\n",
"Company: {company}\n",
"Explanation of their name:\"\"\"\n",
"prompt = PromptTemplate(template=template, input_variables=[\"company\"])\n",
"\n",
"explanation_chain = LLMChain(llm=OpenAI(), prompt=prompt, verbose=True)"
]
},
{
"cell_type": "code",
"execution_count": 39,
"id": "6d636881",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n",
"\n",
"\u001b[1m> Entering new chain...\u001b[0m\n",
"Prompt after formatting:\n",
"\u001b[32;1m\u001b[1;3mFor the following company, explain the origin of their name:\n",
"\n",
"Company: Tesla\n",
"Explanation of their name:\u001b[0m\n",
"\n",
"\u001b[1m> Finished chain.\u001b[0m\n"
]
},
{
"data": {
"text/plain": [
"'\\n\\nTesla is a company that specializes in electric cars and renewable energy. The company is named after Nikola Tesla, a Serbian-American inventor and electrical engineer who was born in the 19th century.'"
]
},
"execution_count": 39,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"explanation_chain.predict(company=\"Tesla\")"
]
},
{
"cell_type": "code",
"execution_count": 40,
"id": "3c236dd3",
"metadata": {},
"outputs": [],
"source": [
"# Now we combine them"
]
},
{
"cell_type": "code",
"execution_count": 41,
"id": "941b6389",
"metadata": {},
"outputs": [],
"source": [
"for_loop_chain = LLMForLoopChain(llm_chain=chain, apply_chain=explanation_chain)"
]
},
{
"cell_type": "code",
"execution_count": 45,
"id": "98c39dbc",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n",
"\n",
"\u001b[1m> Entering new chain...\u001b[0m\n",
"Prompt after formatting:\n",
"\u001b[32;1m\u001b[1;3mFor the following company, explain the origin of their name:\n",
"\n",
"Company: Tesla\n",
"Explanation of their name:\u001b[0m\n",
"\n",
"\u001b[1m> Finished chain.\u001b[0m\n",
"\n",
"\n",
"\u001b[1m> Entering new chain...\u001b[0m\n",
"Prompt after formatting:\n",
"\u001b[32;1m\u001b[1;3mFor the following company, explain the origin of their name:\n",
"\n",
"Company: Nissan\n",
"Explanation of their name:\u001b[0m\n",
"\n",
"\u001b[1m> Finished chain.\u001b[0m\n",
"\n",
"\n",
"\u001b[1m> Entering new chain...\u001b[0m\n",
"Prompt after formatting:\n",
"\u001b[32;1m\u001b[1;3mFor the following company, explain the origin of their name:\n",
"\n",
"Company: BMW\n",
"Explanation of their name:\u001b[0m\n",
"\n",
"\u001b[1m> Finished chain.\u001b[0m\n",
"\n",
"\n",
"\u001b[1m> Entering new chain...\u001b[0m\n",
"Prompt after formatting:\n",
"\u001b[32;1m\u001b[1;3mFor the following company, explain the origin of their name:\n",
"\n",
"Company: BYD\n",
"Explanation of their name:\u001b[0m\n",
"\n",
"\u001b[1m> Finished chain.\u001b[0m\n",
"\n",
"\n",
"\u001b[1m> Entering new chain...\u001b[0m\n",
"Prompt after formatting:\n",
"\u001b[32;1m\u001b[1;3mFor the following company, explain the origin of their name:\n",
"\n",
"Company: Volkswagen\n",
"Explanation of their name:\u001b[0m\n",
"\n",
"\u001b[1m> Finished chain.\u001b[0m\n"
]
},
{
"data": {
"text/plain": [
"['\\n\\nTesla was named after the Serbian-American inventor Nikola Tesla, who was known for his work in electricity and magnetism.',\n",
" '\\n\\nNissan is a Japanese company, and their name comes from the Japanese word for \"sun.\"',\n",
" \"\\n\\nThe company's name is an abbreviation for Bayerische Motoren Werke, which is German for Bavarian Motor Works.\",\n",
" '\\n\\nThe company\\'s name is derived from the Chinese characters \"Baiyu Dong\", which literally mean \"to catch the rain in the east\". The name is a reference to the company\\'s origins in the city of Shenzhen, in southeastern China.',\n",
" '\\n\\nVolkswagen is a German car company. The word \"Volkswagen\" means \"people\\'s car\" in German.']"
]
},
"execution_count": 45,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"for_loop_chain.run_list(input=\"top 5 ev companies\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "5a2c1803",
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.7.6"
}
},
"nbformat": 4,
"nbformat_minor": 5
}

View File

@@ -1,5 +1,5 @@
"""Chain that just formats a prompt and calls an LLM."""
from typing import Any, Dict, List
from typing import Any, Dict, List, Union
from pydantic import BaseModel, Extra
@@ -78,3 +78,11 @@ class LLMChain(Chain, BaseModel):
completion = llm.predict(adjective="funny")
"""
return self(kwargs)[self.output_key]
def predict_and_parse(self, **kwargs: Any) -> Union[str, List[str], Dict[str, str]]:
    """Run the LLM on the formatted prompt, then parse the completion.

    The raw completion is passed through the prompt's ``output_parser``
    when one is configured; otherwise it is returned unchanged.
    """
    completion = self.predict(**kwargs)
    parser = self.prompt.output_parser
    if parser is None:
        return completion
    return parser.parse(completion)

View File

@@ -0,0 +1,63 @@
"""Chain that first uses an LLM to generate multiple items then loops over them."""
from typing import Any, Dict, List
from pydantic import BaseModel, Extra, root_validator
from langchain.chains.base import Chain
from langchain.chains.llm import LLMChain
from langchain.prompts.base import ListOutputParser
class LLMForLoopChain(Chain, BaseModel):
    """Chain that first uses an LLM to generate multiple items then loops over them.

    The ``llm_chain`` produces a list of items (via its prompt's
    ``ListOutputParser``); the ``apply_chain`` is then run once per item.
    """

    llm_chain: LLMChain
    """LLM chain to use to generate multiple items."""
    apply_chain: Chain
    """Chain to apply to each item that is generated."""
    output_key: str = "text"  #: :meta private:

    class Config:
        """Configuration for this pydantic object."""

        extra = Extra.forbid
        arbitrary_types_allowed = True

    @property
    def input_keys(self) -> List[str]:
        """Will be whatever keys the prompt expects.

        :meta private:
        """
        return self.llm_chain.prompt.input_variables

    @property
    def output_keys(self) -> List[str]:
        """Will always return text key.

        :meta private:
        """
        return [self.output_key]

    @root_validator()
    def validate_output_parser(cls, values: Dict) -> Dict:
        """Validate that the correct inputs exist for all chains."""
        chain = values["llm_chain"]
        # The generating chain must parse its completion into a list,
        # otherwise there is nothing to loop over.
        if not isinstance(chain.prompt.output_parser, ListOutputParser):
            raise ValueError(
                f"The OutputParser on the base prompt should be of type "
                f"ListOutputParser, got {type(chain.prompt.output_parser)}"
            )
        return values

    def run_list(self, **kwargs: Any) -> List[str]:
        """Get list from LLM chain and then run chain on each item."""
        items = self.llm_chain.predict_and_parse(**kwargs)
        return [self.apply_chain.run(item) for item in items]

    def _call(self, inputs: Dict[str, str]) -> Dict[str, str]:
        # Collapse the per-item outputs into a single text blob for the
        # standard Chain interface.
        joined = "\n\n".join(self.run_list(**inputs))
        return {self.output_key: joined}

View File

@@ -2,10 +2,10 @@
import json
from abc import ABC, abstractmethod
from pathlib import Path
from typing import Any, Dict, List, Union
from typing import Any, Dict, List, Optional, Union
import yaml
from pydantic import BaseModel, root_validator
from pydantic import BaseModel, Extra, root_validator
from langchain.formatting import formatter
@@ -32,11 +32,35 @@ def check_valid_template(
raise ValueError("Invalid prompt schema.")
class BaseOutputParser(ABC):
    """Class to parse the output of an LLM call."""

    @abstractmethod
    def parse(self, text: str) -> Union[str, List[str], Dict[str, str]]:
        """Parse the output of an LLM call.

        :param text: raw completion text from the LLM.
        :return: the parsed value (string, list, or dict, per subclass).
        """


class ListOutputParser(BaseOutputParser):
    """Class to parse the output of an LLM call to a list.

    Inherits from ``BaseOutputParser`` so that list parsers satisfy the
    ``output_parser: Optional[BaseOutputParser]`` field on
    ``BasePromptTemplate`` and so user-defined parsers that subclass the
    base class can also be recognized where a ``ListOutputParser`` is
    required (e.g. ``LLMForLoopChain``'s output-parser validation).
    """

    @abstractmethod
    def parse(self, text: str) -> List[str]:
        """Parse the output of an LLM call into a list of strings."""
class BasePromptTemplate(BaseModel, ABC):
"""Base prompt should expose the format method, returning a prompt."""
input_variables: List[str]
"""A list of the names of the variables the prompt template expects."""
output_parser: Optional[BaseOutputParser] = None
"""How to parse the output of calling an LLM on this formatted prompt."""
class Config:
"""Configuration for this pydantic object."""
extra = Extra.forbid
arbitrary_types_allowed = True
@root_validator()
def validate_variable_names(cls, values: Dict) -> Dict: