showcase guarded pydantic parsing

This commit is contained in:
jerwelborn
2023-03-20 16:25:35 -07:00
parent 3ee755897e
commit 32a8507829


@@ -75,7 +75,7 @@
{
"data": {
"text/plain": [
"Joke(setup='Why did the chicken cross the playground?', punchline='To get to the other slide!')"
"Joke(setup='Why did the chicken cross the road?', punchline='To get to the other side!')"
]
},
"execution_count": 4,
@@ -124,7 +124,7 @@
{
"data": {
"text/plain": [
"Actor(name='Tom Hanks', film_names=['Forrest Gump', 'Saving Private Ryan', 'The Green Mile', 'Cast Away', 'Toy Story', 'A League of Their Own'])"
"Actor(name='Tom Hanks', film_names=['Forrest Gump', 'Saving Private Ryan', 'The Green Mile', 'Cast Away', 'Toy Story'])"
]
},
"execution_count": 5,
@@ -155,11 +155,354 @@
"parser.parse(output)"
]
},
{
"cell_type": "markdown",
"id": "4d6c0c86",
"metadata": {},
"source": [
"### Aside: adding \"guardrails\" to your parsers.\n",
"\n",
"\"Guardrails\" intuitively add validation logic + optionally retry logic to some black box output, like an LM generating structured output...\n",
"\n",
"Below we'll showcase a \"guarded\" parser which can be dropped into an `LLMChain` as is. It will catch errors at parsing time and try resolve them, initially by re-invoking an LLM. There are many approaches for guardrailing non-deterministic LLMs, here's a simple case."
]
},
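The retry-on-parse-failure idea can be sketched independently of any framework. This is a hypothetical sketch, not langchain's API: `GuardedParser`, `parse_fn`, and the stubbed fallback are illustrative names, with `json.loads` standing in for a strict parser and a canned string standing in for a retry LLM.

```python
import json
from dataclasses import dataclass
from typing import Callable

# Hypothetical sketch (not langchain's API): a parser wrapper that catches
# parse failures and asks a fallback -- e.g. a larger LLM -- for a repaired
# completion before re-parsing.
@dataclass
class GuardedParser:
    parse_fn: Callable[[str], dict]   # the underlying strict parser
    fallback: Callable[[str], str]    # stand-in for "re-invoke a bigger model"

    def parse_with_retry(self, completion: str) -> dict:
        try:
            return self.parse_fn(completion)
        except Exception:
            # First parse failed: get a repaired completion and parse again.
            return self.parse_fn(self.fallback(completion))

# Tiny demo with stubbed components.
guarded = GuardedParser(
    parse_fn=json.loads,
    fallback=lambda bad: '{"values": [0, 1, 1, 2, 3]}',
)
print(guarded.parse_with_retry("not json"))  # {'values': [0, 1, 1, 2, 3]}
```

A real guarded parser would thread the original prompt and the failed completion into the retry call; the control flow is the same.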
{
"cell_type": "code",
"execution_count": 6,
"id": "39b1a5ce",
"metadata": {},
"outputs": [],
"source": [
"from langchain.guardrails.parsing import RetriableOutputParser\n",
"from langchain.output_parsers import OutputParserException\n",
"\n",
"# Note: here we use an LLMChain which slightly abstracts calling an LLM with prompt templates + parsers.\n",
"from langchain.chains import LLMChain"
]
},
{
"cell_type": "markdown",
"id": "d34ddc44",
"metadata": {},
"source": [
"#### 1st example: retry with a larger model."
]
},
{
"cell_type": "code",
"execution_count": 7,
"id": "742cee72",
"metadata": {},
"outputs": [],
"source": [
"# Pydantic data structure.\n",
"class FloatArray(BaseModel):\n",
" values: List[float] = Field(description=\"list of floats\")\n",
"\n",
"# Query that will populate the data structure.\n",
"float_array_query = \"Write out a few terms of fiboacci.\""
]
},
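As an aside on why parsing can fail at all: the pydantic output parser ultimately validates the completion against the model, so anything that isn't JSON matching the schema raises. A minimal sketch using pydantic directly, where `parse_raw` stands in for what the output parser does with the raw completion:

```python
from typing import List
from pydantic import BaseModel, Field, ValidationError

class FloatArray(BaseModel):
    values: List[float] = Field(description="list of floats")

# A completion that is valid JSON matching the schema validates cleanly...
good = '{"values": [0, 1, 1, 2, 3]}'
print(FloatArray.parse_raw(good))

# ...while free-form prose (what a small model often emits) fails fast --
# exactly the error a guarded parser would catch and retry.
try:
    FloatArray.parse_raw("The Fibonacci sequence is 0, 1, 1, 2, 3, ...")
except (ValidationError, ValueError) as exc:
    print("parse failed:", type(exc).__name__)
```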
{
"cell_type": "code",
"execution_count": 8,
"id": "a81f121c",
"metadata": {},
"outputs": [],
"source": [
"# Declare a parser and prompt.\n",
"parser = PydanticOutputParser(pydantic_object=FloatArray)\n",
"\n",
"prompt = PromptTemplate(\n",
" template=\"Answer the user query.\\n{format_instructions}\\n{query}\\n\",\n",
" input_variables=[\"query\"],\n",
" partial_variables={\"format_instructions\": parser.get_format_instructions()}\n",
")\n",
"\n",
"# Currently, the parser is set on the prompt template for use in an LLMChain.\n",
"prompt.output_parser = parser"
]
},
{
"cell_type": "code",
"execution_count": 9,
"id": "492605f0",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n",
"\n",
"\u001b[1m> Entering new LLMChain chain...\u001b[0m\n",
"Prompt after formatting:\n",
"\u001b[32;1m\u001b[1;3mAnswer the user query.\n",
"The output should be formatted as a JSON instance that conforms to the JSON schema below.\n",
"\n",
"As an example, for the schema {\"properties\": {\"foo\": {\"title\": \"Foo\", \"description\": \"a list of strings\", \"type\": \"array\", \"items\": {\"type\": \"string\"}}}, \"required\": [\"foo\"]}}\n",
"the object {\"foo\": [\"bar\", \"baz\"]} is a well-formatted instance of the schema. The object {\"properties\": {\"foo\": [\"bar\", \"baz\"]}} is not well-formatted.\n",
"\n",
"Here is the output schema:\n",
"```\n",
"{\"properties\": {\"values\": {\"title\": \"Values\", \"description\": \"list of floats\", \"type\": \"array\", \"items\": {\"type\": \"number\"}}}, \"required\": [\"values\"]}\n",
"```\n",
"Write out a few terms of fiboacci.\n",
"\u001b[0m\n",
"Dang!\n",
"Failed to parse FloatArray from completion \n",
"A fibonacci sequence is a sequence of numbers where each number in the sequence is the sum of the two previous numbers in the sequence. The first number in the sequence is 0 and the second number in the sequence is 1. The Fibonacci sequence can be expressed as:\n",
"\n",
"0, 1, 1, 2, 3, 5, 8, 13, 21, 34, 55, 89.. Got: Expecting value: line 1 column 1 (char 0)\n"
]
}
],
"source": [
"# For demonstration's sake, we'll use a \"small\" model that probably won't generate json properly.\n",
"llm_chain = LLMChain(\n",
" prompt=prompt,\n",
" llm=OpenAI(model_name=\"text-curie-001\"),\n",
" verbose=True)\n",
"\n",
"try:\n",
" llm_chain.predict(query=float_array_query)\n",
"except OutputParserException as e:\n",
" print(\"Dang!\")\n",
" print(e)"
]
},
{
"cell_type": "code",
"execution_count": 10,
"id": "d2ebac02",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n",
"\n",
"\u001b[1m> Entering new LLMChain chain...\u001b[0m\n",
"Prompt after formatting:\n",
"\u001b[32;1m\u001b[1;3mAnswer the user query.\n",
"The output should be formatted as a JSON instance that conforms to the JSON schema below.\n",
"\n",
"As an example, for the schema {\"properties\": {\"foo\": {\"title\": \"Foo\", \"description\": \"a list of strings\", \"type\": \"array\", \"items\": {\"type\": \"string\"}}}, \"required\": [\"foo\"]}}\n",
"the object {\"foo\": [\"bar\", \"baz\"]} is a well-formatted instance of the schema. The object {\"properties\": {\"foo\": [\"bar\", \"baz\"]}} is not well-formatted.\n",
"\n",
"Here is the output schema:\n",
"```\n",
"{\"properties\": {\"values\": {\"title\": \"Values\", \"description\": \"list of floats\", \"type\": \"array\", \"items\": {\"type\": \"number\"}}}, \"required\": [\"values\"]}\n",
"```\n",
"Write out a few terms of fiboacci.\n",
"\u001b[0m\n",
"\n",
"\u001b[1m> Finished chain.\u001b[0m\n"
]
},
{
"data": {
"text/plain": [
"FloatArray(values=[0.0, 1.0, 1.0, 2.0, 3.0, 5.0, 8.0, 13.0, 21.0, 34.0])"
]
},
"execution_count": 10,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# We can replace the parser with a guarded parser that tries to fix errors with a bigger model.\n",
"guarded_parser = RetriableOutputParser(\n",
" parser=parser, retry_llm=OpenAI(model_name=\"text-davinci-003\"))\n",
"prompt.output_parser = guarded_parser\n",
"\n",
"llm_chain.predict(query=float_array_query)"
]
},
{
"cell_type": "markdown",
"id": "e58cd77c",
"metadata": {},
"source": [
"This example is demonstrative though. If your goal is to generate data structures, probably you'll want to start a large enough model."
]
},
{
"cell_type": "markdown",
"id": "3bb4e21e",
"metadata": {},
"source": [
"#### 2nd example: a more realistic example that a large model sometimes struggles with."
]
},
{
"cell_type": "code",
"execution_count": 11,
"id": "c18c6cd5",
"metadata": {},
"outputs": [],
"source": [
"from enum import Enum\n",
"\n",
"# These data structure will induce a classification & summarization task. Neat!\n",
"class Outcome(str, Enum):\n",
" Purchase = \"Purchase\"\n",
" Objection = \"Objection\"\n",
" class Config: \n",
" use_enum_values = True\n",
"\n",
"class CustomerOutcome(BaseModel):\n",
" outcome: Outcome = Field(description=\"did the customer purchase or object to the offer\")\n",
" reason_for_outcome: str = Field(description=\"why?\")"
]
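One pitfall worth flagging: pydantic config such as `use_enum_values` belongs on the `BaseModel`, not inside the `Enum` body (on Pythons before 3.13, a class nested in an Enum even becomes a spurious member and leaks into the JSON schema). A minimal sketch of the correct placement, with hypothetical `Color`/`Pick` names:

```python
from enum import Enum
from pydantic import BaseModel

class Color(str, Enum):
    Red = "Red"
    Blue = "Blue"

class Pick(BaseModel):
    color: Color

    # Model-level config: validated enum fields are stored as raw values.
    class Config:
        use_enum_values = True

p = Pick(color="Red")
print(repr(p.color))  # 'Red' -- a plain str, not Color.Red
```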
},
{
"cell_type": "code",
"execution_count": 12,
"id": "cbe356e3",
"metadata": {},
"outputs": [],
"source": [
"parser = PydanticOutputParser(pydantic_object=CustomerOutcome)\n",
"\n",
"prompt_template_str = \"\"\"Answer the query below.\n",
"Customer Message: {customer_msg}\n",
"{format_instructions}\n",
"Say whether the Customer accepted or rejected the purchase and summarize why:\"\"\"\n",
"\n",
"prompt = PromptTemplate(\n",
" template=prompt_template_str,\n",
" input_variables=[\"customer_msg\"],\n",
" partial_variables={\"format_instructions\": parser.get_format_instructions()}\n",
")\n",
"prompt.output_parser = parser"
]
},
{
"cell_type": "code",
"execution_count": 13,
"id": "9e331167",
"metadata": {},
"outputs": [],
"source": [
"customer_msg = \"\"\"Nope thats way over budget, can't do it won't do it.\"\"\""
]
},
{
"cell_type": "code",
"execution_count": 14,
"id": "30d0b455",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n",
"\n",
"\u001b[1m> Entering new LLMChain chain...\u001b[0m\n",
"Prompt after formatting:\n",
"\u001b[32;1m\u001b[1;3mAnswer the query below.\n",
"Customer Message: Nope thats way over budget, can't do it won't do it.\n",
"The output should be formatted as a JSON instance that conforms to the JSON schema below.\n",
"\n",
"As an example, for the schema {\"properties\": {\"foo\": {\"title\": \"Foo\", \"description\": \"a list of strings\", \"type\": \"array\", \"items\": {\"type\": \"string\"}}}, \"required\": [\"foo\"]}}\n",
"the object {\"foo\": [\"bar\", \"baz\"]} is a well-formatted instance of the schema. The object {\"properties\": {\"foo\": [\"bar\", \"baz\"]}} is not well-formatted.\n",
"\n",
"Here is the output schema:\n",
"```\n",
"{\"properties\": {\"outcome\": {\"description\": \"did the customer purchase or object to the offer\", \"allOf\": [{\"$ref\": \"#/definitions/Outcome\"}]}, \"reason_for_outcome\": {\"title\": \"Reason For Outcome\", \"description\": \"why?\", \"type\": \"string\"}}, \"required\": [\"outcome\", \"reason_for_outcome\"], \"definitions\": {\"Outcome\": {\"title\": \"Outcome\", \"description\": \"An enumeration.\", \"enum\": [\"Purchase\", \"Objection\", \"<class '__main__.Outcome.Config'>\"], \"type\": \"string\"}}}\n",
"```\n",
"Say whether the Customer accepted or rejected the purchase and summarize why:\u001b[0m\n",
"\n",
"\u001b[1m> Finished chain.\u001b[0m\n",
"outcome=<Outcome.Objection: 'Objection'> reason_for_outcome='The customer said the price was too high.'\n"
]
}
],
"source": [
"llm_chain = LLMChain(\n",
" prompt=prompt,\n",
" llm=OpenAI(model_name=\"text-davinci-001\"),\n",
" verbose=True)\n",
"\n",
"try:\n",
" completion = llm_chain.predict(customer_msg=customer_msg)\n",
" print(completion)\n",
"except OutputParserException as e:\n",
" print(\"Dang!\")\n",
" print(e)"
]
},
{
"cell_type": "markdown",
"id": "d4deae7e",
"metadata": {},
"source": [
"Nice! It worked. We can similarly wrap with a retry to be safe:"
]
},
{
"cell_type": "code",
"execution_count": 15,
"id": "2501766a",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n",
"\n",
"\u001b[1m> Entering new LLMChain chain...\u001b[0m\n",
"Prompt after formatting:\n",
"\u001b[32;1m\u001b[1;3mAnswer the query below.\n",
"Customer Message: Nope thats way over budget, can't do it won't do it.\n",
"The output should be formatted as a JSON instance that conforms to the JSON schema below.\n",
"\n",
"As an example, for the schema {\"properties\": {\"foo\": {\"title\": \"Foo\", \"description\": \"a list of strings\", \"type\": \"array\", \"items\": {\"type\": \"string\"}}}, \"required\": [\"foo\"]}}\n",
"the object {\"foo\": [\"bar\", \"baz\"]} is a well-formatted instance of the schema. The object {\"properties\": {\"foo\": [\"bar\", \"baz\"]}} is not well-formatted.\n",
"\n",
"Here is the output schema:\n",
"```\n",
"{\"properties\": {\"outcome\": {\"description\": \"did the customer purchase or object to the offer\", \"allOf\": [{\"$ref\": \"#/definitions/Outcome\"}]}, \"reason_for_outcome\": {\"title\": \"Reason For Outcome\", \"description\": \"why?\", \"type\": \"string\"}}, \"required\": [\"outcome\", \"reason_for_outcome\"], \"definitions\": {\"Outcome\": {\"title\": \"Outcome\", \"description\": \"An enumeration.\", \"enum\": [\"Purchase\", \"Objection\", \"<class '__main__.Outcome.Config'>\"], \"type\": \"string\"}}}\n",
"```\n",
"Say whether the Customer accepted or rejected the purchase and summarize why:\u001b[0m\n",
"\n",
"\u001b[1m> Finished chain.\u001b[0m\n"
]
},
{
"data": {
"text/plain": [
"CustomerOutcome(outcome=<Outcome.Objection: 'Objection'>, reason_for_outcome='The customer said that the product was over budget.')"
]
},
"execution_count": 15,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"guarded_parser = RetriableOutputParser(\n",
" parser=parser, retry_llm=OpenAI(model_name=\"text-davinci-003\"))\n",
"prompt.output_parser = guarded_parser\n",
"\n",
"llm_chain.predict(customer_msg=customer_msg)"
]
},
{
"cell_type": "markdown",
"id": "61f67890",
"metadata": {},
"source": [
"<br>\n",
"<br>\n",
"<br>\n",
"<br>\n",
"<br>\n",
"<br>\n",
"<br>\n",
@@ -168,6 +511,14 @@
"---"
]
},
{
"cell_type": "markdown",
"id": "64bf525a",
"metadata": {},
"source": [
"# Older, less powerful parsers"
]
},
{
"cell_type": "markdown",
"id": "91871002",
@@ -180,7 +531,7 @@
},
{
"cell_type": "code",
"execution_count": 6,
"execution_count": 16,
"id": "b492997a",
"metadata": {},
"outputs": [],
@@ -198,7 +549,7 @@
},
{
"cell_type": "code",
"execution_count": 7,
"execution_count": 17,
"id": "432ac44a",
"metadata": {},
"outputs": [],
@@ -220,7 +571,7 @@
},
{
"cell_type": "code",
"execution_count": 8,
"execution_count": 18,
"id": "593cfc25",
"metadata": {},
"outputs": [],
@@ -243,7 +594,7 @@
},
{
"cell_type": "code",
"execution_count": 9,
"execution_count": 19,
"id": "106f1ba6",
"metadata": {},
"outputs": [],
@@ -253,7 +604,7 @@
},
{
"cell_type": "code",
"execution_count": 10,
"execution_count": 20,
"id": "86d9d24f",
"metadata": {},
"outputs": [],
@@ -264,7 +615,7 @@
},
{
"cell_type": "code",
"execution_count": 11,
"execution_count": 21,
"id": "956bdc99",
"metadata": {},
"outputs": [
@@ -274,7 +625,7 @@
"{'answer': 'Paris', 'source': 'https://en.wikipedia.org/wiki/Paris'}"
]
},
"execution_count": 11,
"execution_count": 21,
"metadata": {},
"output_type": "execute_result"
}
@@ -293,7 +644,7 @@
},
{
"cell_type": "code",
"execution_count": 12,
"execution_count": 22,
"id": "8f483d7d",
"metadata": {},
"outputs": [],
@@ -303,7 +654,7 @@
},
{
"cell_type": "code",
"execution_count": 13,
"execution_count": 23,
"id": "f761cbf1",
"metadata": {},
"outputs": [],
@@ -319,7 +670,7 @@
},
{
"cell_type": "code",
"execution_count": 14,
"execution_count": 24,
"id": "edd73ae3",
"metadata": {},
"outputs": [],
@@ -330,7 +681,7 @@
},
{
"cell_type": "code",
"execution_count": 15,
"execution_count": 25,
"id": "a3c8b91e",
"metadata": {},
"outputs": [
@@ -340,7 +691,7 @@
"{'answer': 'Paris', 'source': 'https://en.wikipedia.org/wiki/Paris'}"
]
},
"execution_count": 15,
"execution_count": 25,
"metadata": {},
"output_type": "execute_result"
}
@@ -361,7 +712,7 @@
},
{
"cell_type": "code",
"execution_count": 16,
"execution_count": 26,
"id": "872246d7",
"metadata": {},
"outputs": [],
@@ -371,7 +722,7 @@
},
{
"cell_type": "code",
"execution_count": 17,
"execution_count": 27,
"id": "c3f9aee6",
"metadata": {},
"outputs": [],
@@ -381,7 +732,7 @@
},
{
"cell_type": "code",
"execution_count": 18,
"execution_count": 28,
"id": "e77871b7",
"metadata": {},
"outputs": [],
@@ -396,7 +747,7 @@
},
{
"cell_type": "code",
"execution_count": 19,
"execution_count": 29,
"id": "a71cb5d3",
"metadata": {},
"outputs": [],
@@ -406,7 +757,7 @@
},
{
"cell_type": "code",
"execution_count": 20,
"execution_count": 30,
"id": "783d7d98",
"metadata": {},
"outputs": [],
@@ -417,7 +768,7 @@
},
{
"cell_type": "code",
"execution_count": 21,
"execution_count": 31,
"id": "fcb81344",
"metadata": {},
"outputs": [
@@ -431,7 +782,7 @@
" 'Cookies and Cream']"
]
},
"execution_count": 21,
"execution_count": 31,
"metadata": {},
"output_type": "execute_result"
}