Compare commits

...

14 Commits

Author SHA1 Message Date
Harrison Chase
3c96d261d1 cr 2023-03-21 07:54:58 -07:00
Harrison Chase
81b87a6c20 add debug mode 2023-03-21 07:54:11 -07:00
Harrison Chase
86085bc1e4 cr 2023-03-20 19:11:08 -07:00
Harrison Chase
5f41f07b8b Merge branch 'master' into harrison/guarded-output-parser 2023-03-20 19:10:51 -07:00
Harrison Chase
ccc18973b4 cr 2023-03-20 19:10:13 -07:00
jerwelborn
32a8507829 showcase guarded pydantic parsing 2023-03-20 16:25:35 -07:00
jerwelborn
3ee755897e tweak pydantic parser 2023-03-20 16:25:18 -07:00
jerwelborn
a0cde05839 try make guarded/retriable output parser an instance of parser 2023-03-20 13:58:12 -07:00
jerwelborn
325825d55f add example nb 2023-03-20 12:54:21 -07:00
jerwelborn
bfa858b3a6 make parser and guarded parser roughly swappable 2023-03-20 12:41:57 -07:00
jerwelborn
fa2d98c487 factor out 'naive' retry chain 2023-03-20 11:27:04 -07:00
Harrison Chase
6898d8391f cr 2023-03-19 18:05:37 -07:00
Harrison Chase
1af560cca8 cr 2023-03-19 18:00:23 -07:00
Harrison Chase
44d2492427 guarded output parser 2023-03-19 17:57:42 -07:00
20 changed files with 637 additions and 117 deletions

View File

@@ -14,12 +14,17 @@
"- `get_format_instructions() -> str`: A method which returns a string containing instructions for how the output of a language model should be formatted.\n",
"- `parse(str) -> Any`: A method which takes in a string (assumed to be the response from a language model) and parses it into some structure.\n",
"\n",
"And then one optional one:\n",
"\n",
"- `parse_with_prompt(str) -> Any`: A method which takes in a string (assumed to be the response from a language model) and a prompt (assumed to the prompt that generated such a response) and parses it into some structure. The prompt is largely provided in the event the OutputParser wants to retry or fix the output in some way, and needs information from the prompt to do so.\n",
"\n",
"\n",
"Below we go over some examples of output parsers."
]
},
{
"cell_type": "code",
"execution_count": 1,
"execution_count": 2,
"id": "5f0c8a33",
"metadata": {},
"outputs": [],
@@ -44,7 +49,7 @@
},
{
"cell_type": "code",
"execution_count": 2,
"execution_count": 3,
"id": "cba6d8e3",
"metadata": {},
"outputs": [],
@@ -56,7 +61,7 @@
},
{
"cell_type": "code",
"execution_count": 3,
"execution_count": 4,
"id": "0a203100",
"metadata": {},
"outputs": [],
@@ -68,17 +73,17 @@
},
{
"cell_type": "code",
"execution_count": 4,
"execution_count": 5,
"id": "b3f16168",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"Joke(setup='Why did the chicken cross the playground?', punchline='To get to the other slide!')"
"Joke(setup='Why did the chicken cross the road?', punchline='To get to the other side!')"
]
},
"execution_count": 4,
"execution_count": 5,
"metadata": {},
"output_type": "execute_result"
}
@@ -117,17 +122,17 @@
},
{
"cell_type": "code",
"execution_count": 5,
"execution_count": 6,
"id": "03049f88",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"Actor(name='Tom Hanks', film_names=['Forrest Gump', 'Saving Private Ryan', 'The Green Mile', 'Cast Away', 'Toy Story', 'A League of Their Own'])"
"Actor(name='Tom Hanks', film_names=['Forrest Gump', 'Saving Private Ryan', 'The Green Mile', 'Cast Away', 'Toy Story'])"
]
},
"execution_count": 5,
"execution_count": 6,
"metadata": {},
"output_type": "execute_result"
}
@@ -155,11 +160,297 @@
"parser.parse(output)"
]
},
{
"cell_type": "markdown",
"id": "4d6c0c86",
"metadata": {},
"source": [
"## Fixing Output Parsing Mistakes\n",
"\n",
"The above guardrail simply tries to parse the LLM response. If it does not parse correctly, then it errors.\n",
"\n",
"But we can do other things besides throw errors. Specifically, we can pass the misformatted output, along with the formatted instructions, to the model and ask it to fix it.\n",
"\n",
"For this example, we'll use the above OutputParser. Here's what happens if we pass it a result that does not comply with the schema:"
]
},
{
"cell_type": "code",
"execution_count": 7,
"id": "73beb20d",
"metadata": {},
"outputs": [],
"source": [
"misformatted = \"{'name': 'Tom Hanks', 'film_names': ['Forrest Gump']}\""
]
},
{
"cell_type": "code",
"execution_count": 8,
"id": "f0e5ba80",
"metadata": {},
"outputs": [
{
"ename": "OutputParserException",
"evalue": "Failed to parse Actor from completion {'name': 'Tom Hanks', 'film_names': ['Forrest Gump']}. Got: Expecting property name enclosed in double quotes: line 1 column 2 (char 1)",
"output_type": "error",
"traceback": [
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
"\u001b[0;31mJSONDecodeError\u001b[0m Traceback (most recent call last)",
"File \u001b[0;32m~/workplace/langchain/langchain/output_parsers/pydantic.py:23\u001b[0m, in \u001b[0;36mPydanticOutputParser.parse\u001b[0;34m(self, text)\u001b[0m\n\u001b[1;32m 22\u001b[0m json_str \u001b[38;5;241m=\u001b[39m match\u001b[38;5;241m.\u001b[39mgroup()\n\u001b[0;32m---> 23\u001b[0m json_object \u001b[38;5;241m=\u001b[39m \u001b[43mjson\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mloads\u001b[49m\u001b[43m(\u001b[49m\u001b[43mjson_str\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 24\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mpydantic_object\u001b[38;5;241m.\u001b[39mparse_obj(json_object)\n",
"File \u001b[0;32m~/.pyenv/versions/3.9.1/lib/python3.9/json/__init__.py:346\u001b[0m, in \u001b[0;36mloads\u001b[0;34m(s, cls, object_hook, parse_float, parse_int, parse_constant, object_pairs_hook, **kw)\u001b[0m\n\u001b[1;32m 343\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m (\u001b[38;5;28mcls\u001b[39m \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m \u001b[38;5;129;01mand\u001b[39;00m object_hook \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m \u001b[38;5;129;01mand\u001b[39;00m\n\u001b[1;32m 344\u001b[0m parse_int \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m \u001b[38;5;129;01mand\u001b[39;00m parse_float \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m \u001b[38;5;129;01mand\u001b[39;00m\n\u001b[1;32m 345\u001b[0m parse_constant \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m \u001b[38;5;129;01mand\u001b[39;00m object_pairs_hook \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m \u001b[38;5;129;01mand\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m kw):\n\u001b[0;32m--> 346\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43m_default_decoder\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mdecode\u001b[49m\u001b[43m(\u001b[49m\u001b[43ms\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 347\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mcls\u001b[39m \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m:\n",
"File \u001b[0;32m~/.pyenv/versions/3.9.1/lib/python3.9/json/decoder.py:337\u001b[0m, in \u001b[0;36mJSONDecoder.decode\u001b[0;34m(self, s, _w)\u001b[0m\n\u001b[1;32m 333\u001b[0m \u001b[38;5;250m\u001b[39m\u001b[38;5;124;03m\"\"\"Return the Python representation of ``s`` (a ``str`` instance\u001b[39;00m\n\u001b[1;32m 334\u001b[0m \u001b[38;5;124;03mcontaining a JSON document).\u001b[39;00m\n\u001b[1;32m 335\u001b[0m \n\u001b[1;32m 336\u001b[0m \u001b[38;5;124;03m\"\"\"\u001b[39;00m\n\u001b[0;32m--> 337\u001b[0m obj, end \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mraw_decode\u001b[49m\u001b[43m(\u001b[49m\u001b[43ms\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43midx\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43m_w\u001b[49m\u001b[43m(\u001b[49m\u001b[43ms\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m0\u001b[39;49m\u001b[43m)\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mend\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 338\u001b[0m end \u001b[38;5;241m=\u001b[39m _w(s, end)\u001b[38;5;241m.\u001b[39mend()\n",
"File \u001b[0;32m~/.pyenv/versions/3.9.1/lib/python3.9/json/decoder.py:353\u001b[0m, in \u001b[0;36mJSONDecoder.raw_decode\u001b[0;34m(self, s, idx)\u001b[0m\n\u001b[1;32m 352\u001b[0m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[0;32m--> 353\u001b[0m obj, end \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mscan_once\u001b[49m\u001b[43m(\u001b[49m\u001b[43ms\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43midx\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 354\u001b[0m \u001b[38;5;28;01mexcept\u001b[39;00m \u001b[38;5;167;01mStopIteration\u001b[39;00m \u001b[38;5;28;01mas\u001b[39;00m err:\n",
"\u001b[0;31mJSONDecodeError\u001b[0m: Expecting property name enclosed in double quotes: line 1 column 2 (char 1)",
"\nDuring handling of the above exception, another exception occurred:\n",
"\u001b[0;31mOutputParserException\u001b[0m Traceback (most recent call last)",
"Cell \u001b[0;32mIn[8], line 1\u001b[0m\n\u001b[0;32m----> 1\u001b[0m \u001b[43mparser\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mparse\u001b[49m\u001b[43m(\u001b[49m\u001b[43mmisformatted\u001b[49m\u001b[43m)\u001b[49m\n",
"File \u001b[0;32m~/workplace/langchain/langchain/output_parsers/pydantic.py:29\u001b[0m, in \u001b[0;36mPydanticOutputParser.parse\u001b[0;34m(self, text)\u001b[0m\n\u001b[1;32m 27\u001b[0m name \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mpydantic_object\u001b[38;5;241m.\u001b[39m\u001b[38;5;18m__name__\u001b[39m\n\u001b[1;32m 28\u001b[0m msg \u001b[38;5;241m=\u001b[39m \u001b[38;5;124mf\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mFailed to parse \u001b[39m\u001b[38;5;132;01m{\u001b[39;00mname\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m from completion \u001b[39m\u001b[38;5;132;01m{\u001b[39;00mtext\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m. Got: \u001b[39m\u001b[38;5;132;01m{\u001b[39;00me\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m\"\u001b[39m\n\u001b[0;32m---> 29\u001b[0m \u001b[38;5;28;01mraise\u001b[39;00m OutputParserException(msg)\n",
"\u001b[0;31mOutputParserException\u001b[0m: Failed to parse Actor from completion {'name': 'Tom Hanks', 'film_names': ['Forrest Gump']}. Got: Expecting property name enclosed in double quotes: line 1 column 2 (char 1)"
]
}
],
"source": [
"parser.parse(misformatted)"
]
},
{
"cell_type": "markdown",
"id": "6c7c82b6",
"metadata": {},
"source": [
"Now we can construct and use a `FixOutputParser`. This output parser takes as an argument another output parser but also an LLM with which to try to correct any formatting mistakes."
]
},
{
"cell_type": "code",
"execution_count": 15,
"id": "39b1a5ce",
"metadata": {},
"outputs": [],
"source": [
"from langchain.output_parsers import FixOutputParser\n",
"\n",
"new_parser = FixOutputParser.from_llm(parser=parser, llm=ChatOpenAI())"
]
},
{
"cell_type": "code",
"execution_count": 16,
"id": "0fd96d68",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"Action(action='AddFilm', action_input='Forrest Gump')"
]
},
"execution_count": 16,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"new_parser.parse(misformatted)"
]
},
{
"cell_type": "markdown",
"id": "ea34eeaa",
"metadata": {},
"source": [
"## Fixing Output Parsing Mistakes with the original prompt\n",
"\n",
"While in some cases it is possible to fix any parsing mistakes by only looking at the output, in other cases it can't. An example of this is when the output is not just in the incorrect format, but is partially complete. Consider the below example."
]
},
{
"cell_type": "code",
"execution_count": 9,
"id": "67c5e1ac",
"metadata": {},
"outputs": [],
"source": [
"template = \"\"\"Based on the user question, provide an Action and Action Input for what step should be taken.\n",
"{format_instructions}\n",
"Question: {query}\n",
"Response:\"\"\"\n",
"class Action(BaseModel):\n",
" action: str = Field(description=\"action to take\")\n",
" action_input: str = Field(description=\"input to the action\")\n",
" \n",
"parser = PydanticOutputParser(pydantic_object=Action)"
]
},
{
"cell_type": "code",
"execution_count": 10,
"id": "007aa87f",
"metadata": {},
"outputs": [],
"source": [
"prompt = PromptTemplate(\n",
" template=\"Answer the user query.\\n{format_instructions}\\n{query}\\n\",\n",
" input_variables=[\"query\"],\n",
" partial_variables={\"format_instructions\": parser.get_format_instructions()}\n",
")"
]
},
{
"cell_type": "code",
"execution_count": 11,
"id": "10d207ff",
"metadata": {},
"outputs": [],
"source": [
"prompt_value = prompt.format_prompt(query=\"who is leo di caprios gf?\")"
]
},
{
"cell_type": "code",
"execution_count": 12,
"id": "68622837",
"metadata": {},
"outputs": [],
"source": [
"bad_response = '{\"action\": \"search\"}'"
]
},
{
"cell_type": "markdown",
"id": "25631465",
"metadata": {},
"source": [
"If we try to parse this response as is, we will get an error"
]
},
{
"cell_type": "code",
"execution_count": 13,
"id": "894967c1",
"metadata": {},
"outputs": [
{
"ename": "OutputParserException",
"evalue": "Failed to parse Action from completion {\"action\": \"search\"}. Got: 1 validation error for Action\naction_input\n field required (type=value_error.missing)",
"output_type": "error",
"traceback": [
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
"\u001b[0;31mValidationError\u001b[0m Traceback (most recent call last)",
"File \u001b[0;32m~/workplace/langchain/langchain/output_parsers/pydantic.py:24\u001b[0m, in \u001b[0;36mPydanticOutputParser.parse\u001b[0;34m(self, text)\u001b[0m\n\u001b[1;32m 23\u001b[0m json_object \u001b[38;5;241m=\u001b[39m json\u001b[38;5;241m.\u001b[39mloads(json_str)\n\u001b[0;32m---> 24\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mpydantic_object\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mparse_obj\u001b[49m\u001b[43m(\u001b[49m\u001b[43mjson_object\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 26\u001b[0m \u001b[38;5;28;01mexcept\u001b[39;00m (json\u001b[38;5;241m.\u001b[39mJSONDecodeError, ValidationError) \u001b[38;5;28;01mas\u001b[39;00m e:\n",
"File \u001b[0;32m~/.pyenv/versions/3.9.1/envs/langchain/lib/python3.9/site-packages/pydantic/main.py:527\u001b[0m, in \u001b[0;36mpydantic.main.BaseModel.parse_obj\u001b[0;34m()\u001b[0m\n",
"File \u001b[0;32m~/.pyenv/versions/3.9.1/envs/langchain/lib/python3.9/site-packages/pydantic/main.py:342\u001b[0m, in \u001b[0;36mpydantic.main.BaseModel.__init__\u001b[0;34m()\u001b[0m\n",
"\u001b[0;31mValidationError\u001b[0m: 1 validation error for Action\naction_input\n field required (type=value_error.missing)",
"\nDuring handling of the above exception, another exception occurred:\n",
"\u001b[0;31mOutputParserException\u001b[0m Traceback (most recent call last)",
"Cell \u001b[0;32mIn[13], line 1\u001b[0m\n\u001b[0;32m----> 1\u001b[0m \u001b[43mparser\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mparse\u001b[49m\u001b[43m(\u001b[49m\u001b[43mbad_response\u001b[49m\u001b[43m)\u001b[49m\n",
"File \u001b[0;32m~/workplace/langchain/langchain/output_parsers/pydantic.py:29\u001b[0m, in \u001b[0;36mPydanticOutputParser.parse\u001b[0;34m(self, text)\u001b[0m\n\u001b[1;32m 27\u001b[0m name \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mpydantic_object\u001b[38;5;241m.\u001b[39m\u001b[38;5;18m__name__\u001b[39m\n\u001b[1;32m 28\u001b[0m msg \u001b[38;5;241m=\u001b[39m \u001b[38;5;124mf\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mFailed to parse \u001b[39m\u001b[38;5;132;01m{\u001b[39;00mname\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m from completion \u001b[39m\u001b[38;5;132;01m{\u001b[39;00mtext\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m. Got: \u001b[39m\u001b[38;5;132;01m{\u001b[39;00me\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m\"\u001b[39m\n\u001b[0;32m---> 29\u001b[0m \u001b[38;5;28;01mraise\u001b[39;00m OutputParserException(msg)\n",
"\u001b[0;31mOutputParserException\u001b[0m: Failed to parse Action from completion {\"action\": \"search\"}. Got: 1 validation error for Action\naction_input\n field required (type=value_error.missing)"
]
}
],
"source": [
"parser.parse(bad_response)"
]
},
{
"cell_type": "markdown",
"id": "f6b64696",
"metadata": {},
"source": [
"If we try to use the FixOutputParser to fix this error, it will be confused - namely, it doesn't know what to actually put for action input."
]
},
{
"cell_type": "code",
"execution_count": 17,
"id": "78b2b40d",
"metadata": {},
"outputs": [],
"source": [
"fix_parser = FixOutputParser.from_llm(parser=parser, llm=ChatOpenAI())"
]
},
{
"cell_type": "code",
"execution_count": 18,
"id": "4fe1301d",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"Action(action='search', action_input='query')"
]
},
"execution_count": 18,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"fix_parser.parse(bad_response)"
]
},
{
"cell_type": "markdown",
"id": "9bd9ea7d",
"metadata": {},
"source": [
"Instead, we can use the RetryOutputParser, which passes in the prompt (as well as the original output) to try again to get a better response."
]
},
{
"cell_type": "code",
"execution_count": 19,
"id": "7e8a8a28",
"metadata": {},
"outputs": [],
"source": [
"from langchain.output_parsers import RetryWithErrorOutputParser"
]
},
{
"cell_type": "code",
"execution_count": 20,
"id": "5c86e141",
"metadata": {},
"outputs": [],
"source": [
"retry_parser = RetryWithErrorOutputParser.from_llm(parser=parser, llm=ChatOpenAI())"
]
},
{
"cell_type": "code",
"execution_count": 21,
"id": "9c04f731",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"Action(action='search', action_input='leo di caprios girlfriend')"
]
},
"execution_count": 21,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"retry_parser.parse_with_prompt(bad_response, prompt_value)"
]
},
{
"cell_type": "markdown",
"id": "61f67890",
"metadata": {},
"source": [
"<br>\n",
"<br>\n",
"<br>\n",
"<br>\n",
"<br>\n",
"<br>\n",
"<br>\n",
@@ -168,6 +459,14 @@
"---"
]
},
{
"cell_type": "markdown",
"id": "64bf525a",
"metadata": {},
"source": [
"# Older, less powerful parsers"
]
},
{
"cell_type": "markdown",
"id": "91871002",
@@ -180,7 +479,7 @@
},
{
"cell_type": "code",
"execution_count": 6,
"execution_count": 16,
"id": "b492997a",
"metadata": {},
"outputs": [],
@@ -198,7 +497,7 @@
},
{
"cell_type": "code",
"execution_count": 7,
"execution_count": 17,
"id": "432ac44a",
"metadata": {},
"outputs": [],
@@ -220,7 +519,7 @@
},
{
"cell_type": "code",
"execution_count": 8,
"execution_count": 18,
"id": "593cfc25",
"metadata": {},
"outputs": [],
@@ -243,7 +542,7 @@
},
{
"cell_type": "code",
"execution_count": 9,
"execution_count": 19,
"id": "106f1ba6",
"metadata": {},
"outputs": [],
@@ -253,7 +552,7 @@
},
{
"cell_type": "code",
"execution_count": 10,
"execution_count": 20,
"id": "86d9d24f",
"metadata": {},
"outputs": [],
@@ -264,7 +563,7 @@
},
{
"cell_type": "code",
"execution_count": 11,
"execution_count": 21,
"id": "956bdc99",
"metadata": {},
"outputs": [
@@ -274,7 +573,7 @@
"{'answer': 'Paris', 'source': 'https://en.wikipedia.org/wiki/Paris'}"
]
},
"execution_count": 11,
"execution_count": 21,
"metadata": {},
"output_type": "execute_result"
}
@@ -293,7 +592,7 @@
},
{
"cell_type": "code",
"execution_count": 12,
"execution_count": 22,
"id": "8f483d7d",
"metadata": {},
"outputs": [],
@@ -303,7 +602,7 @@
},
{
"cell_type": "code",
"execution_count": 13,
"execution_count": 23,
"id": "f761cbf1",
"metadata": {},
"outputs": [],
@@ -319,7 +618,7 @@
},
{
"cell_type": "code",
"execution_count": 14,
"execution_count": 24,
"id": "edd73ae3",
"metadata": {},
"outputs": [],
@@ -330,7 +629,7 @@
},
{
"cell_type": "code",
"execution_count": 15,
"execution_count": 25,
"id": "a3c8b91e",
"metadata": {},
"outputs": [
@@ -340,7 +639,7 @@
"{'answer': 'Paris', 'source': 'https://en.wikipedia.org/wiki/Paris'}"
]
},
"execution_count": 15,
"execution_count": 25,
"metadata": {},
"output_type": "execute_result"
}
@@ -361,7 +660,7 @@
},
{
"cell_type": "code",
"execution_count": 16,
"execution_count": 26,
"id": "872246d7",
"metadata": {},
"outputs": [],
@@ -371,7 +670,7 @@
},
{
"cell_type": "code",
"execution_count": 17,
"execution_count": 27,
"id": "c3f9aee6",
"metadata": {},
"outputs": [],
@@ -381,7 +680,7 @@
},
{
"cell_type": "code",
"execution_count": 18,
"execution_count": 28,
"id": "e77871b7",
"metadata": {},
"outputs": [],
@@ -396,7 +695,7 @@
},
{
"cell_type": "code",
"execution_count": 19,
"execution_count": 29,
"id": "a71cb5d3",
"metadata": {},
"outputs": [],
@@ -406,7 +705,7 @@
},
{
"cell_type": "code",
"execution_count": 20,
"execution_count": 30,
"id": "783d7d98",
"metadata": {},
"outputs": [],
@@ -417,7 +716,7 @@
},
{
"cell_type": "code",
"execution_count": 21,
"execution_count": 31,
"id": "fcb81344",
"metadata": {},
"outputs": [
@@ -431,7 +730,7 @@
" 'Cookies and Cream']"
]
},
"execution_count": 21,
"execution_count": 31,
"metadata": {},
"output_type": "execute_result"
}
@@ -457,7 +756,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.0"
"version": "3.9.1"
}
},
"nbformat": 4,

View File

@@ -13,7 +13,6 @@ from langchain.agents.conversational_chat.prompt import (
)
from langchain.callbacks.base import BaseCallbackManager
from langchain.chains import LLMChain
from langchain.output_parsers.base import BaseOutputParser
from langchain.prompts.base import BasePromptTemplate
from langchain.prompts.chat import (
ChatPromptTemplate,
@@ -26,6 +25,7 @@ from langchain.schema import (
AIMessage,
BaseLanguageModel,
BaseMessage,
BaseOutputParser,
HumanMessage,
)
from langchain.tools.base import BaseTool

View File

@@ -52,7 +52,7 @@ class HypotheticalDocumentEmbedder(Chain, Embeddings, BaseModel):
def embed_query(self, text: str) -> List[float]:
"""Generate a hypothetical document and embedded it."""
var_name = self.llm_chain.input_keys[0]
result = self.llm_chain.generate([{var_name: text}])
result, _ = self.llm_chain.generate([{var_name: text}])
documents = [generation.text for generation in result.generations[0]]
embeddings = self.embed_documents(documents)
return self.combine_embeddings(embeddings)

View File

@@ -1,15 +1,22 @@
"""Chain that just formats a prompt and calls an LLM."""
from __future__ import annotations
import warnings
from typing import Any, Dict, List, Optional, Sequence, Tuple, Union
from pydantic import BaseModel, Extra
from pydantic import BaseModel, Extra, root_validator
from langchain.chains.base import Chain
from langchain.input import get_colored_text
from langchain.prompts.base import BasePromptTemplate
from langchain.prompts.prompt import PromptTemplate
from langchain.schema import BaseLanguageModel, LLMResult, PromptValue
from langchain.schema import (
BaseLanguageModel,
BaseOutputParser,
Generation,
LLMResult,
PromptValue,
)
class LLMChain(Chain, BaseModel):
@@ -30,6 +37,26 @@ class LLMChain(Chain, BaseModel):
"""Prompt object to use."""
llm: BaseLanguageModel
output_key: str = "text" #: :meta private:
output_parser: Optional[BaseOutputParser] = None
debug: bool = False
@root_validator()
def validate_output_parser(cls, values: Dict) -> Dict:
"""Validate output parser."""
prompt: BasePromptTemplate = values["prompt"]
output_parser = values["output_parser"]
if prompt.output_parser is not None:
warnings.warn(
"Got an output parser on the prompt "
"- please transition this to being passed into the LLMChain."
)
if output_parser is not None:
raise ValueError(
"Got an output parser on the prompt as well on the LLMChain - "
"should only be provided in one place."
)
values["output_parser"] = prompt.output_parser
return values
class Config:
"""Configuration for this pydantic object."""
@@ -51,20 +78,27 @@ class LLMChain(Chain, BaseModel):
:meta private:
"""
return [self.output_key]
if not self.debug:
return [self.output_key]
else:
return [self.output_key, "raw", "error"]
def _call(self, inputs: Dict[str, Any]) -> Dict[str, str]:
return self.apply([inputs])[0]
def generate(self, input_list: List[Dict[str, Any]]) -> LLMResult:
def generate(
self, input_list: List[Dict[str, Any]]
) -> Tuple[LLMResult, List[PromptValue]]:
"""Generate LLM result from inputs."""
prompts, stop = self.prep_prompts(input_list)
return self.llm.generate_prompt(prompts, stop)
return self.llm.generate_prompt(prompts, stop), prompts
async def agenerate(self, input_list: List[Dict[str, Any]]) -> LLMResult:
async def agenerate(
self, input_list: List[Dict[str, Any]]
) -> Tuple[LLMResult, List[PromptValue]]:
"""Generate LLM result from inputs."""
prompts, stop = await self.aprep_prompts(input_list)
return await self.llm.agenerate_prompt(prompts, stop)
return await self.llm.agenerate_prompt(prompts, stop), prompts
def prep_prompts(
self, input_list: List[Dict[str, Any]]
@@ -115,20 +149,47 @@ class LLMChain(Chain, BaseModel):
def apply(self, input_list: List[Dict[str, Any]]) -> List[Dict[str, str]]:
"""Utilize the LLM generate method for speed gains."""
response = self.generate(input_list)
return self.create_outputs(response)
response, prompts = self.generate(input_list)
return self.create_outputs(response, prompts)
async def aapply(self, input_list: List[Dict[str, Any]]) -> List[Dict[str, str]]:
"""Utilize the LLM generate method for speed gains."""
response = await self.agenerate(input_list)
return self.create_outputs(response)
response, prompts = await self.agenerate(input_list)
return self.create_outputs(response, prompts)
def create_outputs(self, response: LLMResult) -> List[Dict[str, str]]:
def _get_final_output(
self, generations: List[Generation], prompt_value: PromptValue
) -> Dict:
"""Get the final output from a list of generations for a prompt."""
completion = generations[0].text
if self.output_parser is not None:
try:
new_completion = self.output_parser.parse_with_prompt(
completion, prompt_value
)
result = {self.output_key: new_completion}
if self.debug:
result["raw"] = completion
result["errors"] = []
except Exception as e:
if self.debug:
result = {
self.output_key: None,
"raw": completion,
"error": [repr(e)],
}
else:
result = {self.output_key: completion}
return result
def create_outputs(
self, response: LLMResult, prompts: List[PromptValue]
) -> List[Dict[str, str]]:
"""Create outputs from response."""
return [
# Get the text of the top generated string.
{self.output_key: generation[0].text}
for generation in response.generations
self._get_final_output(generation, prompts[i])
for i, generation in enumerate(response.generations)
]
async def _acall(self, inputs: Dict[str, Any]) -> Dict[str, str]:
@@ -166,37 +227,23 @@ class LLMChain(Chain, BaseModel):
"""
return (await self.acall(kwargs))[self.output_key]
# If an output_parser is provided, it should always be applied.
# TODO: remove these methods.
def predict_and_parse(self, **kwargs: Any) -> Union[str, List[str], Dict[str, str]]:
"""Call predict and then parse the results."""
result = self.predict(**kwargs)
if self.prompt.output_parser is not None:
return self.prompt.output_parser.parse(result)
else:
return result
return self.predict(**kwargs)
def apply_and_parse(
self, input_list: List[Dict[str, Any]]
) -> Sequence[Union[str, List[str], Dict[str, str]]]:
"""Call apply and then parse the results."""
result = self.apply(input_list)
return self._parse_result(result)
def _parse_result(
self, result: List[Dict[str, str]]
) -> Sequence[Union[str, List[str], Dict[str, str]]]:
if self.prompt.output_parser is not None:
return [
self.prompt.output_parser.parse(res[self.output_key]) for res in result
]
else:
return result
return self.apply(input_list)
async def aapply_and_parse(
self, input_list: List[Dict[str, Any]]
) -> Sequence[Union[str, List[str], Dict[str, str]]]:
"""Call apply and then parse the results."""
result = await self.aapply(input_list)
return self._parse_result(result)
return await self.aapply(input_list)
@property
def _chain_type(self) -> str:

View File

@@ -47,7 +47,7 @@ class QAGenerationChain(Chain):
def _call(self, inputs: Dict[str, str]) -> Dict[str, Any]:
docs = self.text_splitter.create_documents([inputs[self.input_key]])
results = self.llm_chain.generate([{"text": d.page_content} for d in docs])
results, _ = self.llm_chain.generate([{"text": d.page_content} for d in docs])
qa = [json.loads(res[0].text) for res in results.generations]
return {self.output_key: qa}

View File

@@ -1,4 +1,4 @@
from langchain.output_parsers.base import BaseOutputParser
from langchain.output_parsers.fix import FixOutputParser
from langchain.output_parsers.list import (
CommaSeparatedListOutputParser,
ListOutputParser,
@@ -7,6 +7,7 @@ from langchain.output_parsers.pydantic import PydanticOutputParser
from langchain.output_parsers.rail_parser import GuardrailsOutputParser
from langchain.output_parsers.regex import RegexParser
from langchain.output_parsers.regex_dict import RegexDictParser
from langchain.output_parsers.retry import RetryOutputParser, RetryWithErrorOutputParser
from langchain.output_parsers.structured import ResponseSchema, StructuredOutputParser
__all__ = [
@@ -14,9 +15,11 @@ __all__ = [
"RegexDictParser",
"ListOutputParser",
"CommaSeparatedListOutputParser",
"BaseOutputParser",
"StructuredOutputParser",
"ResponseSchema",
"GuardrailsOutputParser",
"PydanticOutputParser",
"RetryOutputParser",
"RetryWithErrorOutputParser",
"FixOutputParser",
]

View File

@@ -1,28 +0,0 @@
from __future__ import annotations
from abc import ABC, abstractmethod
from typing import Any, Dict
from pydantic import BaseModel
class BaseOutputParser(BaseModel, ABC):
"""Class to parse the output of an LLM call."""
@abstractmethod
def parse(self, text: str) -> Any:
"""Parse the output of an LLM call."""
def get_format_instructions(self) -> str:
raise NotImplementedError
@property
def _type(self) -> str:
"""Return the type key."""
raise NotImplementedError
def dict(self, **kwargs: Any) -> Dict:
"""Return dictionary representation of output parser."""
output_parser_dict = super().dict()
output_parser_dict["_type"] = self._type
return output_parser_dict

View File

@@ -0,0 +1,41 @@
from __future__ import annotations
from typing import Any
from langchain.chains.llm import LLMChain
from langchain.output_parsers.prompts import NAIVE_FIX_PROMPT
from langchain.prompts.base import BasePromptTemplate
from langchain.schema import BaseLanguageModel, BaseOutputParser, OutputParserException
class FixOutputParser(BaseOutputParser):
"""Wraps a parser and tries to fix parsing errors."""
parser: BaseOutputParser
retry_chain: LLMChain
@classmethod
def from_llm(
cls,
llm: BaseLanguageModel,
parser: BaseOutputParser,
prompt: BasePromptTemplate = NAIVE_FIX_PROMPT,
) -> FixOutputParser:
chain = LLMChain(llm=llm, prompt=prompt)
return cls(parser=parser, retry_chain=chain)
def parse(self, completion: str) -> Any:
try:
parsed_completion = self.parser.parse(completion)
except OutputParserException as e:
new_completion = self.retry_chain.run(
instructions=self.parser.get_format_instructions(),
completion=completion,
error=repr(e),
)
parsed_completion = self.parser.parse(new_completion)
return parsed_completion
def get_format_instructions(self) -> str:
return self.parser.get_format_instructions()

View File

@@ -8,7 +8,10 @@ STRUCTURED_FORMAT_INSTRUCTIONS = """The output should be a markdown code snippet
}}
```"""
PYDANTIC_FORMAT_INSTRUCTIONS = """The output should be formatted as a JSON instance that conforms to the JSON schema below. For example, the object {{"foo": ["bar", "baz"]}} conforms to the schema {{"foo": {{"description": "a list of strings field", "type": "string"}}}}.
PYDANTIC_FORMAT_INSTRUCTIONS = """The output should be formatted as a JSON instance that conforms to the JSON schema below.
As an example, for the schema {{"properties": {{"foo": {{"title": "Foo", "description": "a list of strings", "type": "array", "items": {{"type": "string"}}}}}}, "required": ["foo"]}}}}
the object {{"foo": ["bar", "baz"]}} is a well-formatted instance of the schema. The object {{"properties": {{"foo": ["bar", "baz"]}}}} is not well-formatted.
Here is the output schema:
```

View File

@@ -3,7 +3,7 @@ from __future__ import annotations
from abc import abstractmethod
from typing import List
from langchain.output_parsers.base import BaseOutputParser
from langchain.schema import BaseOutputParser
class ListOutputParser(BaseOutputParser):

View File

@@ -0,0 +1,22 @@
# flake8: noqa
from langchain.prompts.prompt import PromptTemplate
NAIVE_FIX = """Instructions:
--------------
{instructions}
--------------
Completion:
--------------
{completion}
--------------
Above, the Completion did not satisfy the constraints given in the Instructions.
Error:
--------------
{error}
--------------
Please try again. Please only respond with an answer that satisfies the constraints laid out in the Instructions:"""
NAIVE_FIX_PROMPT = PromptTemplate.from_template(NAIVE_FIX)

View File

@@ -4,8 +4,8 @@ from typing import Any
from pydantic import BaseModel, ValidationError
from langchain.output_parsers.base import BaseOutputParser
from langchain.output_parsers.format_instructions import PYDANTIC_FORMAT_INSTRUCTIONS
from langchain.schema import BaseOutputParser, OutputParserException
class PydanticOutputParser(BaseOutputParser):
@@ -14,7 +14,9 @@ class PydanticOutputParser(BaseOutputParser):
def parse(self, text: str) -> BaseModel:
try:
# Greedy search for 1st json candidate.
match = re.search("\{.*\}", text.strip())
match = re.search(
"\{.*\}", text.strip(), re.MULTILINE | re.IGNORECASE | re.DOTALL
)
json_str = ""
if match:
json_str = match.group()
@@ -24,16 +26,17 @@ class PydanticOutputParser(BaseOutputParser):
except (json.JSONDecodeError, ValidationError) as e:
name = self.pydantic_object.__name__
msg = f"Failed to parse {name} from completion {text}. Got: {e}"
raise ValueError(msg)
raise OutputParserException(msg)
def get_format_instructions(self) -> str:
schema = self.pydantic_object.schema()
# Remove extraneous fields.
reduced_schema = {
prop: {"description": data["description"], "type": data["type"]}
for prop, data in schema["properties"].items()
}
reduced_schema = schema
if "title" in reduced_schema:
del reduced_schema["title"]
if "type" in reduced_schema:
del reduced_schema["type"]
# Ensure json in context is well-formed with double quotes.
schema = json.dumps(reduced_schema)

View File

@@ -2,7 +2,7 @@ from __future__ import annotations
from typing import Any, Dict
from langchain.output_parsers.base import BaseOutputParser
from langchain.schema import BaseOutputParser
class GuardrailsOutputParser(BaseOutputParser):

View File

@@ -5,7 +5,7 @@ from typing import Dict, List, Optional
from pydantic import BaseModel
from langchain.output_parsers.base import BaseOutputParser
from langchain.schema import BaseOutputParser
class RegexParser(BaseOutputParser, BaseModel):

View File

@@ -5,7 +5,7 @@ from typing import Dict, Optional
from pydantic import BaseModel
from langchain.output_parsers.base import BaseOutputParser
from langchain.schema import BaseOutputParser
class RegexDictParser(BaseOutputParser, BaseModel):

View File

@@ -0,0 +1,107 @@
from __future__ import annotations
from typing import Any
from langchain.chains.llm import LLMChain
from langchain.prompts.base import BasePromptTemplate
from langchain.prompts.prompt import PromptTemplate
from langchain.schema import (
BaseLanguageModel,
BaseOutputParser,
OutputParserException,
PromptValue,
)
NAIVE_COMPLETION_RETRY = """Prompt:
{prompt}
Completion:
{completion}
Above, the Completion did not satisfy the constraints given in the Prompt.
Please try again:"""
NAIVE_COMPLETION_RETRY_WITH_ERROR = """Prompt:
{prompt}
Completion:
{completion}
Above, the Completion did not satisfy the constraints given in the Prompt.
Details: {error}
Please try again:"""
NAIVE_RETRY_PROMPT = PromptTemplate.from_template(NAIVE_COMPLETION_RETRY)
NAIVE_RETRY_WITH_ERROR_PROMPT = PromptTemplate.from_template(
NAIVE_COMPLETION_RETRY_WITH_ERROR
)
class RetryOutputParser(BaseOutputParser):
"""Wraps a parser and tries to fix parsing errors."""
parser: BaseOutputParser
retry_chain: LLMChain
@classmethod
def from_llm(
cls,
llm: BaseLanguageModel,
parser: BaseOutputParser,
prompt: BasePromptTemplate = NAIVE_RETRY_PROMPT,
) -> RetryOutputParser:
chain = LLMChain(llm=llm, prompt=prompt)
return cls(parser=parser, retry_chain=chain)
def parse_with_prompt(self, completion: str, prompt_value: PromptValue) -> Any:
try:
parsed_completion = self.parser.parse(completion)
except OutputParserException:
new_completion = self.retry_chain.run(
prompt=prompt_value.to_string(), completion=completion
)
parsed_completion = self.parser.parse(new_completion)
return parsed_completion
def parse(self, completion: str) -> Any:
raise NotImplementedError(
"This OutputParser can only be called by the `parse_with_prompt` method."
)
def get_format_instructions(self) -> str:
return self.parser.get_format_instructions()
class RetryWithErrorOutputParser(BaseOutputParser):
"""Wraps a parser and tries to fix parsing errors."""
parser: BaseOutputParser
retry_chain: LLMChain
@classmethod
def from_llm(
cls,
llm: BaseLanguageModel,
parser: BaseOutputParser,
prompt: BasePromptTemplate = NAIVE_RETRY_WITH_ERROR_PROMPT,
) -> RetryWithErrorOutputParser:
chain = LLMChain(llm=llm, prompt=prompt)
return cls(parser=parser, retry_chain=chain)
def parse_with_prompt(self, completion: str, prompt_value: PromptValue) -> Any:
try:
parsed_completion = self.parser.parse(completion)
except OutputParserException as e:
new_completion = self.retry_chain.run(
prompt=prompt_value.to_string(), completion=completion, error=repr(e)
)
parsed_completion = self.parser.parse(new_completion)
return parsed_completion
def parse(self, completion: str) -> Any:
raise NotImplementedError(
"This OutputParser can only be called by the `parse_with_prompt` method."
)
def get_format_instructions(self) -> str:
return self.parser.get_format_instructions()

View File

@@ -5,8 +5,8 @@ from typing import List
from pydantic import BaseModel
from langchain.output_parsers.base import BaseOutputParser
from langchain.output_parsers.format_instructions import STRUCTURED_FORMAT_INSTRUCTIONS
from langchain.schema import BaseOutputParser, OutputParserException
line_template = '\t"{name}": {type} // {description}'
@@ -42,7 +42,7 @@ class StructuredOutputParser(BaseOutputParser):
json_obj = json.loads(json_string)
for schema in self.response_schemas:
if schema.name not in json_obj:
raise ValueError(
raise OutputParserException(
f"Got invalid return object. Expected key `{schema.name}` "
f"to be present, but got {json_obj}"
)

View File

@@ -10,13 +10,7 @@ import yaml
from pydantic import BaseModel, Extra, Field, root_validator
from langchain.formatting import formatter
from langchain.output_parsers.base import BaseOutputParser
from langchain.output_parsers.list import ( # noqa: F401
CommaSeparatedListOutputParser,
ListOutputParser,
)
from langchain.output_parsers.regex import RegexParser # noqa: F401
from langchain.schema import BaseMessage, HumanMessage, PromptValue
from langchain.schema import BaseMessage, BaseOutputParser, HumanMessage, PromptValue
def jinja2_formatter(template: str, **kwargs: Any) -> str:

View File

@@ -259,3 +259,32 @@ class BaseMemory(BaseModel, ABC):
Memory = BaseMemory
class BaseOutputParser(BaseModel, ABC):
"""Class to parse the output of an LLM call."""
@abstractmethod
def parse(self, text: str) -> Any:
"""Parse the output of an LLM call."""
def parse_with_prompt(self, completion: str, prompt: PromptValue) -> Any:
return self.parse(completion)
def get_format_instructions(self) -> str:
raise NotImplementedError
@property
def _type(self) -> str:
"""Return the type key."""
raise NotImplementedError
def dict(self, **kwargs: Any) -> Dict:
"""Return dictionary representation of output parser."""
output_parser_dict = super().dict()
output_parser_dict["_type"] = self._type
return output_parser_dict
class OutputParserException(Exception):
pass

View File

@@ -7,8 +7,8 @@ import pytest
from langchain.chains.llm import LLMChain
from langchain.chains.loading import load_chain
from langchain.output_parsers.base import BaseOutputParser
from langchain.prompts.prompt import PromptTemplate
from langchain.schema import BaseOutputParser
from tests.unit_tests.llms.fake_llm import FakeLLM