mirror of
https://github.com/hwchase17/langchain.git
synced 2026-01-24 05:50:18 +00:00
Compare commits
14 Commits
langchain-
...
harrison/d
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
3c96d261d1 | ||
|
|
81b87a6c20 | ||
|
|
86085bc1e4 | ||
|
|
5f41f07b8b | ||
|
|
ccc18973b4 | ||
|
|
32a8507829 | ||
|
|
3ee755897e | ||
|
|
a0cde05839 | ||
|
|
325825d55f | ||
|
|
bfa858b3a6 | ||
|
|
fa2d98c487 | ||
|
|
6898d8391f | ||
|
|
1af560cca8 | ||
|
|
44d2492427 |
@@ -14,12 +14,17 @@
|
||||
"- `get_format_instructions() -> str`: A method which returns a string containing instructions for how the output of a language model should be formatted.\n",
|
||||
"- `parse(str) -> Any`: A method which takes in a string (assumed to be the response from a language model) and parses it into some structure.\n",
|
||||
"\n",
|
||||
"And then one optional one:\n",
|
||||
"\n",
|
||||
"- `parse_with_prompt(str) -> Any`: A method which takes in a string (assumed to be the response from a language model) and a prompt (assumed to the prompt that generated such a response) and parses it into some structure. The prompt is largely provided in the event the OutputParser wants to retry or fix the output in some way, and needs information from the prompt to do so.\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"Below we go over some examples of output parsers."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 1,
|
||||
"execution_count": 2,
|
||||
"id": "5f0c8a33",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
@@ -44,7 +49,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 2,
|
||||
"execution_count": 3,
|
||||
"id": "cba6d8e3",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
@@ -56,7 +61,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 3,
|
||||
"execution_count": 4,
|
||||
"id": "0a203100",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
@@ -68,17 +73,17 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 4,
|
||||
"execution_count": 5,
|
||||
"id": "b3f16168",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"data": {
|
||||
"text/plain": [
|
||||
"Joke(setup='Why did the chicken cross the playground?', punchline='To get to the other slide!')"
|
||||
"Joke(setup='Why did the chicken cross the road?', punchline='To get to the other side!')"
|
||||
]
|
||||
},
|
||||
"execution_count": 4,
|
||||
"execution_count": 5,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
@@ -117,17 +122,17 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 5,
|
||||
"execution_count": 6,
|
||||
"id": "03049f88",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"data": {
|
||||
"text/plain": [
|
||||
"Actor(name='Tom Hanks', film_names=['Forrest Gump', 'Saving Private Ryan', 'The Green Mile', 'Cast Away', 'Toy Story', 'A League of Their Own'])"
|
||||
"Actor(name='Tom Hanks', film_names=['Forrest Gump', 'Saving Private Ryan', 'The Green Mile', 'Cast Away', 'Toy Story'])"
|
||||
]
|
||||
},
|
||||
"execution_count": 5,
|
||||
"execution_count": 6,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
@@ -155,11 +160,297 @@
|
||||
"parser.parse(output)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "4d6c0c86",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## Fixing Output Parsing Mistakes\n",
|
||||
"\n",
|
||||
"The above guardrail simply tries to parse the LLM response. If it does not parse correctly, then it errors.\n",
|
||||
"\n",
|
||||
"But we can do other things besides throw errors. Specifically, we can pass the misformatted output, along with the formatted instructions, to the model and ask it to fix it.\n",
|
||||
"\n",
|
||||
"For this example, we'll use the above OutputParser. Here's what happens if we pass it a result that does not comply with the schema:"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 7,
|
||||
"id": "73beb20d",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"misformatted = \"{'name': 'Tom Hanks', 'film_names': ['Forrest Gump']}\""
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 8,
|
||||
"id": "f0e5ba80",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"ename": "OutputParserException",
|
||||
"evalue": "Failed to parse Actor from completion {'name': 'Tom Hanks', 'film_names': ['Forrest Gump']}. Got: Expecting property name enclosed in double quotes: line 1 column 2 (char 1)",
|
||||
"output_type": "error",
|
||||
"traceback": [
|
||||
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
|
||||
"\u001b[0;31mJSONDecodeError\u001b[0m Traceback (most recent call last)",
|
||||
"File \u001b[0;32m~/workplace/langchain/langchain/output_parsers/pydantic.py:23\u001b[0m, in \u001b[0;36mPydanticOutputParser.parse\u001b[0;34m(self, text)\u001b[0m\n\u001b[1;32m 22\u001b[0m json_str \u001b[38;5;241m=\u001b[39m match\u001b[38;5;241m.\u001b[39mgroup()\n\u001b[0;32m---> 23\u001b[0m json_object \u001b[38;5;241m=\u001b[39m \u001b[43mjson\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mloads\u001b[49m\u001b[43m(\u001b[49m\u001b[43mjson_str\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 24\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mpydantic_object\u001b[38;5;241m.\u001b[39mparse_obj(json_object)\n",
|
||||
"File \u001b[0;32m~/.pyenv/versions/3.9.1/lib/python3.9/json/__init__.py:346\u001b[0m, in \u001b[0;36mloads\u001b[0;34m(s, cls, object_hook, parse_float, parse_int, parse_constant, object_pairs_hook, **kw)\u001b[0m\n\u001b[1;32m 343\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m (\u001b[38;5;28mcls\u001b[39m \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m \u001b[38;5;129;01mand\u001b[39;00m object_hook \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m \u001b[38;5;129;01mand\u001b[39;00m\n\u001b[1;32m 344\u001b[0m parse_int \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m \u001b[38;5;129;01mand\u001b[39;00m parse_float \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m \u001b[38;5;129;01mand\u001b[39;00m\n\u001b[1;32m 345\u001b[0m parse_constant \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m \u001b[38;5;129;01mand\u001b[39;00m object_pairs_hook \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m \u001b[38;5;129;01mand\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m kw):\n\u001b[0;32m--> 346\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43m_default_decoder\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mdecode\u001b[49m\u001b[43m(\u001b[49m\u001b[43ms\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 347\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mcls\u001b[39m \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m:\n",
|
||||
"File \u001b[0;32m~/.pyenv/versions/3.9.1/lib/python3.9/json/decoder.py:337\u001b[0m, in \u001b[0;36mJSONDecoder.decode\u001b[0;34m(self, s, _w)\u001b[0m\n\u001b[1;32m 333\u001b[0m \u001b[38;5;250m\u001b[39m\u001b[38;5;124;03m\"\"\"Return the Python representation of ``s`` (a ``str`` instance\u001b[39;00m\n\u001b[1;32m 334\u001b[0m \u001b[38;5;124;03mcontaining a JSON document).\u001b[39;00m\n\u001b[1;32m 335\u001b[0m \n\u001b[1;32m 336\u001b[0m \u001b[38;5;124;03m\"\"\"\u001b[39;00m\n\u001b[0;32m--> 337\u001b[0m obj, end \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mraw_decode\u001b[49m\u001b[43m(\u001b[49m\u001b[43ms\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43midx\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43m_w\u001b[49m\u001b[43m(\u001b[49m\u001b[43ms\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m0\u001b[39;49m\u001b[43m)\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mend\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 338\u001b[0m end \u001b[38;5;241m=\u001b[39m _w(s, end)\u001b[38;5;241m.\u001b[39mend()\n",
|
||||
"File \u001b[0;32m~/.pyenv/versions/3.9.1/lib/python3.9/json/decoder.py:353\u001b[0m, in \u001b[0;36mJSONDecoder.raw_decode\u001b[0;34m(self, s, idx)\u001b[0m\n\u001b[1;32m 352\u001b[0m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[0;32m--> 353\u001b[0m obj, end \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mscan_once\u001b[49m\u001b[43m(\u001b[49m\u001b[43ms\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43midx\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 354\u001b[0m \u001b[38;5;28;01mexcept\u001b[39;00m \u001b[38;5;167;01mStopIteration\u001b[39;00m \u001b[38;5;28;01mas\u001b[39;00m err:\n",
|
||||
"\u001b[0;31mJSONDecodeError\u001b[0m: Expecting property name enclosed in double quotes: line 1 column 2 (char 1)",
|
||||
"\nDuring handling of the above exception, another exception occurred:\n",
|
||||
"\u001b[0;31mOutputParserException\u001b[0m Traceback (most recent call last)",
|
||||
"Cell \u001b[0;32mIn[8], line 1\u001b[0m\n\u001b[0;32m----> 1\u001b[0m \u001b[43mparser\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mparse\u001b[49m\u001b[43m(\u001b[49m\u001b[43mmisformatted\u001b[49m\u001b[43m)\u001b[49m\n",
|
||||
"File \u001b[0;32m~/workplace/langchain/langchain/output_parsers/pydantic.py:29\u001b[0m, in \u001b[0;36mPydanticOutputParser.parse\u001b[0;34m(self, text)\u001b[0m\n\u001b[1;32m 27\u001b[0m name \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mpydantic_object\u001b[38;5;241m.\u001b[39m\u001b[38;5;18m__name__\u001b[39m\n\u001b[1;32m 28\u001b[0m msg \u001b[38;5;241m=\u001b[39m \u001b[38;5;124mf\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mFailed to parse \u001b[39m\u001b[38;5;132;01m{\u001b[39;00mname\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m from completion \u001b[39m\u001b[38;5;132;01m{\u001b[39;00mtext\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m. Got: \u001b[39m\u001b[38;5;132;01m{\u001b[39;00me\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m\"\u001b[39m\n\u001b[0;32m---> 29\u001b[0m \u001b[38;5;28;01mraise\u001b[39;00m OutputParserException(msg)\n",
|
||||
"\u001b[0;31mOutputParserException\u001b[0m: Failed to parse Actor from completion {'name': 'Tom Hanks', 'film_names': ['Forrest Gump']}. Got: Expecting property name enclosed in double quotes: line 1 column 2 (char 1)"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"parser.parse(misformatted)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "6c7c82b6",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"Now we can construct and use a `FixOutputParser`. This output parser takes as an argument another output parser but also an LLM with which to try to correct any formatting mistakes."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 15,
|
||||
"id": "39b1a5ce",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from langchain.output_parsers import FixOutputParser\n",
|
||||
"\n",
|
||||
"new_parser = FixOutputParser.from_llm(parser=parser, llm=ChatOpenAI())"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 16,
|
||||
"id": "0fd96d68",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"data": {
|
||||
"text/plain": [
|
||||
"Action(action='AddFilm', action_input='Forrest Gump')"
|
||||
]
|
||||
},
|
||||
"execution_count": 16,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"new_parser.parse(misformatted)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "ea34eeaa",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## Fixing Output Parsing Mistakes with the original prompt\n",
|
||||
"\n",
|
||||
"While in some cases it is possible to fix any parsing mistakes by only looking at the output, in other cases it can't. An example of this is when the output is not just in the incorrect format, but is partially complete. Consider the below example."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 9,
|
||||
"id": "67c5e1ac",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"template = \"\"\"Based on the user question, provide an Action and Action Input for what step should be taken.\n",
|
||||
"{format_instructions}\n",
|
||||
"Question: {query}\n",
|
||||
"Response:\"\"\"\n",
|
||||
"class Action(BaseModel):\n",
|
||||
" action: str = Field(description=\"action to take\")\n",
|
||||
" action_input: str = Field(description=\"input to the action\")\n",
|
||||
" \n",
|
||||
"parser = PydanticOutputParser(pydantic_object=Action)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 10,
|
||||
"id": "007aa87f",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"prompt = PromptTemplate(\n",
|
||||
" template=\"Answer the user query.\\n{format_instructions}\\n{query}\\n\",\n",
|
||||
" input_variables=[\"query\"],\n",
|
||||
" partial_variables={\"format_instructions\": parser.get_format_instructions()}\n",
|
||||
")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 11,
|
||||
"id": "10d207ff",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"prompt_value = prompt.format_prompt(query=\"who is leo di caprios gf?\")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 12,
|
||||
"id": "68622837",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"bad_response = '{\"action\": \"search\"}'"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "25631465",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"If we try to parse this response as is, we will get an error"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 13,
|
||||
"id": "894967c1",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"ename": "OutputParserException",
|
||||
"evalue": "Failed to parse Action from completion {\"action\": \"search\"}. Got: 1 validation error for Action\naction_input\n field required (type=value_error.missing)",
|
||||
"output_type": "error",
|
||||
"traceback": [
|
||||
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
|
||||
"\u001b[0;31mValidationError\u001b[0m Traceback (most recent call last)",
|
||||
"File \u001b[0;32m~/workplace/langchain/langchain/output_parsers/pydantic.py:24\u001b[0m, in \u001b[0;36mPydanticOutputParser.parse\u001b[0;34m(self, text)\u001b[0m\n\u001b[1;32m 23\u001b[0m json_object \u001b[38;5;241m=\u001b[39m json\u001b[38;5;241m.\u001b[39mloads(json_str)\n\u001b[0;32m---> 24\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mpydantic_object\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mparse_obj\u001b[49m\u001b[43m(\u001b[49m\u001b[43mjson_object\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 26\u001b[0m \u001b[38;5;28;01mexcept\u001b[39;00m (json\u001b[38;5;241m.\u001b[39mJSONDecodeError, ValidationError) \u001b[38;5;28;01mas\u001b[39;00m e:\n",
|
||||
"File \u001b[0;32m~/.pyenv/versions/3.9.1/envs/langchain/lib/python3.9/site-packages/pydantic/main.py:527\u001b[0m, in \u001b[0;36mpydantic.main.BaseModel.parse_obj\u001b[0;34m()\u001b[0m\n",
|
||||
"File \u001b[0;32m~/.pyenv/versions/3.9.1/envs/langchain/lib/python3.9/site-packages/pydantic/main.py:342\u001b[0m, in \u001b[0;36mpydantic.main.BaseModel.__init__\u001b[0;34m()\u001b[0m\n",
|
||||
"\u001b[0;31mValidationError\u001b[0m: 1 validation error for Action\naction_input\n field required (type=value_error.missing)",
|
||||
"\nDuring handling of the above exception, another exception occurred:\n",
|
||||
"\u001b[0;31mOutputParserException\u001b[0m Traceback (most recent call last)",
|
||||
"Cell \u001b[0;32mIn[13], line 1\u001b[0m\n\u001b[0;32m----> 1\u001b[0m \u001b[43mparser\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mparse\u001b[49m\u001b[43m(\u001b[49m\u001b[43mbad_response\u001b[49m\u001b[43m)\u001b[49m\n",
|
||||
"File \u001b[0;32m~/workplace/langchain/langchain/output_parsers/pydantic.py:29\u001b[0m, in \u001b[0;36mPydanticOutputParser.parse\u001b[0;34m(self, text)\u001b[0m\n\u001b[1;32m 27\u001b[0m name \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mpydantic_object\u001b[38;5;241m.\u001b[39m\u001b[38;5;18m__name__\u001b[39m\n\u001b[1;32m 28\u001b[0m msg \u001b[38;5;241m=\u001b[39m \u001b[38;5;124mf\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mFailed to parse \u001b[39m\u001b[38;5;132;01m{\u001b[39;00mname\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m from completion \u001b[39m\u001b[38;5;132;01m{\u001b[39;00mtext\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m. Got: \u001b[39m\u001b[38;5;132;01m{\u001b[39;00me\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m\"\u001b[39m\n\u001b[0;32m---> 29\u001b[0m \u001b[38;5;28;01mraise\u001b[39;00m OutputParserException(msg)\n",
|
||||
"\u001b[0;31mOutputParserException\u001b[0m: Failed to parse Action from completion {\"action\": \"search\"}. Got: 1 validation error for Action\naction_input\n field required (type=value_error.missing)"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"parser.parse(bad_response)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "f6b64696",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"If we try to use the FixOutputParser to fix this error, it will be confused - namely, it doesn't know what to actually put for action input."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 17,
|
||||
"id": "78b2b40d",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"fix_parser = FixOutputParser.from_llm(parser=parser, llm=ChatOpenAI())"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 18,
|
||||
"id": "4fe1301d",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"data": {
|
||||
"text/plain": [
|
||||
"Action(action='search', action_input='query')"
|
||||
]
|
||||
},
|
||||
"execution_count": 18,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"fix_parser.parse(bad_response)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "9bd9ea7d",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"Instead, we can use the RetryOutputParser, which passes in the prompt (as well as the original output) to try again to get a better response."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 19,
|
||||
"id": "7e8a8a28",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from langchain.output_parsers import RetryWithErrorOutputParser"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 20,
|
||||
"id": "5c86e141",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"retry_parser = RetryWithErrorOutputParser.from_llm(parser=parser, llm=ChatOpenAI())"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 21,
|
||||
"id": "9c04f731",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"data": {
|
||||
"text/plain": [
|
||||
"Action(action='search', action_input='leo di caprios girlfriend')"
|
||||
]
|
||||
},
|
||||
"execution_count": 21,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"retry_parser.parse_with_prompt(bad_response, prompt_value)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "61f67890",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"<br>\n",
|
||||
"<br>\n",
|
||||
"<br>\n",
|
||||
"<br>\n",
|
||||
"<br>\n",
|
||||
"<br>\n",
|
||||
"<br>\n",
|
||||
@@ -168,6 +459,14 @@
|
||||
"---"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "64bf525a",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"# Older, less powerful parsers"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "91871002",
|
||||
@@ -180,7 +479,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 6,
|
||||
"execution_count": 16,
|
||||
"id": "b492997a",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
@@ -198,7 +497,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 7,
|
||||
"execution_count": 17,
|
||||
"id": "432ac44a",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
@@ -220,7 +519,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 8,
|
||||
"execution_count": 18,
|
||||
"id": "593cfc25",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
@@ -243,7 +542,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 9,
|
||||
"execution_count": 19,
|
||||
"id": "106f1ba6",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
@@ -253,7 +552,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 10,
|
||||
"execution_count": 20,
|
||||
"id": "86d9d24f",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
@@ -264,7 +563,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 11,
|
||||
"execution_count": 21,
|
||||
"id": "956bdc99",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
@@ -274,7 +573,7 @@
|
||||
"{'answer': 'Paris', 'source': 'https://en.wikipedia.org/wiki/Paris'}"
|
||||
]
|
||||
},
|
||||
"execution_count": 11,
|
||||
"execution_count": 21,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
@@ -293,7 +592,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 12,
|
||||
"execution_count": 22,
|
||||
"id": "8f483d7d",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
@@ -303,7 +602,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 13,
|
||||
"execution_count": 23,
|
||||
"id": "f761cbf1",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
@@ -319,7 +618,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 14,
|
||||
"execution_count": 24,
|
||||
"id": "edd73ae3",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
@@ -330,7 +629,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 15,
|
||||
"execution_count": 25,
|
||||
"id": "a3c8b91e",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
@@ -340,7 +639,7 @@
|
||||
"{'answer': 'Paris', 'source': 'https://en.wikipedia.org/wiki/Paris'}"
|
||||
]
|
||||
},
|
||||
"execution_count": 15,
|
||||
"execution_count": 25,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
@@ -361,7 +660,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 16,
|
||||
"execution_count": 26,
|
||||
"id": "872246d7",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
@@ -371,7 +670,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 17,
|
||||
"execution_count": 27,
|
||||
"id": "c3f9aee6",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
@@ -381,7 +680,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 18,
|
||||
"execution_count": 28,
|
||||
"id": "e77871b7",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
@@ -396,7 +695,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 19,
|
||||
"execution_count": 29,
|
||||
"id": "a71cb5d3",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
@@ -406,7 +705,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 20,
|
||||
"execution_count": 30,
|
||||
"id": "783d7d98",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
@@ -417,7 +716,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 21,
|
||||
"execution_count": 31,
|
||||
"id": "fcb81344",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
@@ -431,7 +730,7 @@
|
||||
" 'Cookies and Cream']"
|
||||
]
|
||||
},
|
||||
"execution_count": 21,
|
||||
"execution_count": 31,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
@@ -457,7 +756,7 @@
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.9.0"
|
||||
"version": "3.9.1"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
|
||||
@@ -13,7 +13,6 @@ from langchain.agents.conversational_chat.prompt import (
|
||||
)
|
||||
from langchain.callbacks.base import BaseCallbackManager
|
||||
from langchain.chains import LLMChain
|
||||
from langchain.output_parsers.base import BaseOutputParser
|
||||
from langchain.prompts.base import BasePromptTemplate
|
||||
from langchain.prompts.chat import (
|
||||
ChatPromptTemplate,
|
||||
@@ -26,6 +25,7 @@ from langchain.schema import (
|
||||
AIMessage,
|
||||
BaseLanguageModel,
|
||||
BaseMessage,
|
||||
BaseOutputParser,
|
||||
HumanMessage,
|
||||
)
|
||||
from langchain.tools.base import BaseTool
|
||||
|
||||
@@ -52,7 +52,7 @@ class HypotheticalDocumentEmbedder(Chain, Embeddings, BaseModel):
|
||||
def embed_query(self, text: str) -> List[float]:
|
||||
"""Generate a hypothetical document and embedded it."""
|
||||
var_name = self.llm_chain.input_keys[0]
|
||||
result = self.llm_chain.generate([{var_name: text}])
|
||||
result, _ = self.llm_chain.generate([{var_name: text}])
|
||||
documents = [generation.text for generation in result.generations[0]]
|
||||
embeddings = self.embed_documents(documents)
|
||||
return self.combine_embeddings(embeddings)
|
||||
|
||||
@@ -1,15 +1,22 @@
|
||||
"""Chain that just formats a prompt and calls an LLM."""
|
||||
from __future__ import annotations
|
||||
|
||||
import warnings
|
||||
from typing import Any, Dict, List, Optional, Sequence, Tuple, Union
|
||||
|
||||
from pydantic import BaseModel, Extra
|
||||
from pydantic import BaseModel, Extra, root_validator
|
||||
|
||||
from langchain.chains.base import Chain
|
||||
from langchain.input import get_colored_text
|
||||
from langchain.prompts.base import BasePromptTemplate
|
||||
from langchain.prompts.prompt import PromptTemplate
|
||||
from langchain.schema import BaseLanguageModel, LLMResult, PromptValue
|
||||
from langchain.schema import (
|
||||
BaseLanguageModel,
|
||||
BaseOutputParser,
|
||||
Generation,
|
||||
LLMResult,
|
||||
PromptValue,
|
||||
)
|
||||
|
||||
|
||||
class LLMChain(Chain, BaseModel):
|
||||
@@ -30,6 +37,26 @@ class LLMChain(Chain, BaseModel):
|
||||
"""Prompt object to use."""
|
||||
llm: BaseLanguageModel
|
||||
output_key: str = "text" #: :meta private:
|
||||
output_parser: Optional[BaseOutputParser] = None
|
||||
debug: bool = False
|
||||
|
||||
@root_validator()
|
||||
def validate_output_parser(cls, values: Dict) -> Dict:
|
||||
"""Validate output parser."""
|
||||
prompt: BasePromptTemplate = values["prompt"]
|
||||
output_parser = values["output_parser"]
|
||||
if prompt.output_parser is not None:
|
||||
warnings.warn(
|
||||
"Got an output parser on the prompt "
|
||||
"- please transition this to being passed into the LLMChain."
|
||||
)
|
||||
if output_parser is not None:
|
||||
raise ValueError(
|
||||
"Got an output parser on the prompt as well on the LLMChain - "
|
||||
"should only be provided in one place."
|
||||
)
|
||||
values["output_parser"] = prompt.output_parser
|
||||
return values
|
||||
|
||||
class Config:
|
||||
"""Configuration for this pydantic object."""
|
||||
@@ -51,20 +78,27 @@ class LLMChain(Chain, BaseModel):
|
||||
|
||||
:meta private:
|
||||
"""
|
||||
return [self.output_key]
|
||||
if not self.debug:
|
||||
return [self.output_key]
|
||||
else:
|
||||
return [self.output_key, "raw", "error"]
|
||||
|
||||
def _call(self, inputs: Dict[str, Any]) -> Dict[str, str]:
|
||||
return self.apply([inputs])[0]
|
||||
|
||||
def generate(self, input_list: List[Dict[str, Any]]) -> LLMResult:
|
||||
def generate(
|
||||
self, input_list: List[Dict[str, Any]]
|
||||
) -> Tuple[LLMResult, List[PromptValue]]:
|
||||
"""Generate LLM result from inputs."""
|
||||
prompts, stop = self.prep_prompts(input_list)
|
||||
return self.llm.generate_prompt(prompts, stop)
|
||||
return self.llm.generate_prompt(prompts, stop), prompts
|
||||
|
||||
async def agenerate(self, input_list: List[Dict[str, Any]]) -> LLMResult:
|
||||
async def agenerate(
|
||||
self, input_list: List[Dict[str, Any]]
|
||||
) -> Tuple[LLMResult, List[PromptValue]]:
|
||||
"""Generate LLM result from inputs."""
|
||||
prompts, stop = await self.aprep_prompts(input_list)
|
||||
return await self.llm.agenerate_prompt(prompts, stop)
|
||||
return await self.llm.agenerate_prompt(prompts, stop), prompts
|
||||
|
||||
def prep_prompts(
|
||||
self, input_list: List[Dict[str, Any]]
|
||||
@@ -115,20 +149,47 @@ class LLMChain(Chain, BaseModel):
|
||||
|
||||
def apply(self, input_list: List[Dict[str, Any]]) -> List[Dict[str, str]]:
|
||||
"""Utilize the LLM generate method for speed gains."""
|
||||
response = self.generate(input_list)
|
||||
return self.create_outputs(response)
|
||||
response, prompts = self.generate(input_list)
|
||||
return self.create_outputs(response, prompts)
|
||||
|
||||
async def aapply(self, input_list: List[Dict[str, Any]]) -> List[Dict[str, str]]:
|
||||
"""Utilize the LLM generate method for speed gains."""
|
||||
response = await self.agenerate(input_list)
|
||||
return self.create_outputs(response)
|
||||
response, prompts = await self.agenerate(input_list)
|
||||
return self.create_outputs(response, prompts)
|
||||
|
||||
def create_outputs(self, response: LLMResult) -> List[Dict[str, str]]:
|
||||
def _get_final_output(
|
||||
self, generations: List[Generation], prompt_value: PromptValue
|
||||
) -> Dict:
|
||||
"""Get the final output from a list of generations for a prompt."""
|
||||
completion = generations[0].text
|
||||
if self.output_parser is not None:
|
||||
try:
|
||||
new_completion = self.output_parser.parse_with_prompt(
|
||||
completion, prompt_value
|
||||
)
|
||||
result = {self.output_key: new_completion}
|
||||
if self.debug:
|
||||
result["raw"] = completion
|
||||
result["errors"] = []
|
||||
except Exception as e:
|
||||
if self.debug:
|
||||
result = {
|
||||
self.output_key: None,
|
||||
"raw": completion,
|
||||
"error": [repr(e)],
|
||||
}
|
||||
else:
|
||||
result = {self.output_key: completion}
|
||||
return result
|
||||
|
||||
def create_outputs(
|
||||
self, response: LLMResult, prompts: List[PromptValue]
|
||||
) -> List[Dict[str, str]]:
|
||||
"""Create outputs from response."""
|
||||
return [
|
||||
# Get the text of the top generated string.
|
||||
{self.output_key: generation[0].text}
|
||||
for generation in response.generations
|
||||
self._get_final_output(generation, prompts[i])
|
||||
for i, generation in enumerate(response.generations)
|
||||
]
|
||||
|
||||
async def _acall(self, inputs: Dict[str, Any]) -> Dict[str, str]:
|
||||
@@ -166,37 +227,23 @@ class LLMChain(Chain, BaseModel):
|
||||
"""
|
||||
return (await self.acall(kwargs))[self.output_key]
|
||||
|
||||
# If an output_parser is provided, it should always be applied.
|
||||
# TODO: remove these methods.
|
||||
def predict_and_parse(self, **kwargs: Any) -> Union[str, List[str], Dict[str, str]]:
|
||||
"""Call predict and then parse the results."""
|
||||
result = self.predict(**kwargs)
|
||||
if self.prompt.output_parser is not None:
|
||||
return self.prompt.output_parser.parse(result)
|
||||
else:
|
||||
return result
|
||||
return self.predict(**kwargs)
|
||||
|
||||
def apply_and_parse(
|
||||
self, input_list: List[Dict[str, Any]]
|
||||
) -> Sequence[Union[str, List[str], Dict[str, str]]]:
|
||||
"""Call apply and then parse the results."""
|
||||
result = self.apply(input_list)
|
||||
return self._parse_result(result)
|
||||
|
||||
def _parse_result(
|
||||
self, result: List[Dict[str, str]]
|
||||
) -> Sequence[Union[str, List[str], Dict[str, str]]]:
|
||||
if self.prompt.output_parser is not None:
|
||||
return [
|
||||
self.prompt.output_parser.parse(res[self.output_key]) for res in result
|
||||
]
|
||||
else:
|
||||
return result
|
||||
return self.apply(input_list)
|
||||
|
||||
async def aapply_and_parse(
|
||||
self, input_list: List[Dict[str, Any]]
|
||||
) -> Sequence[Union[str, List[str], Dict[str, str]]]:
|
||||
"""Call apply and then parse the results."""
|
||||
result = await self.aapply(input_list)
|
||||
return self._parse_result(result)
|
||||
return await self.aapply(input_list)
|
||||
|
||||
@property
|
||||
def _chain_type(self) -> str:
|
||||
|
||||
@@ -47,7 +47,7 @@ class QAGenerationChain(Chain):
|
||||
|
||||
def _call(self, inputs: Dict[str, str]) -> Dict[str, Any]:
|
||||
docs = self.text_splitter.create_documents([inputs[self.input_key]])
|
||||
results = self.llm_chain.generate([{"text": d.page_content} for d in docs])
|
||||
results, _ = self.llm_chain.generate([{"text": d.page_content} for d in docs])
|
||||
qa = [json.loads(res[0].text) for res in results.generations]
|
||||
return {self.output_key: qa}
|
||||
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
from langchain.output_parsers.base import BaseOutputParser
|
||||
from langchain.output_parsers.fix import FixOutputParser
|
||||
from langchain.output_parsers.list import (
|
||||
CommaSeparatedListOutputParser,
|
||||
ListOutputParser,
|
||||
@@ -7,6 +7,7 @@ from langchain.output_parsers.pydantic import PydanticOutputParser
|
||||
from langchain.output_parsers.rail_parser import GuardrailsOutputParser
|
||||
from langchain.output_parsers.regex import RegexParser
|
||||
from langchain.output_parsers.regex_dict import RegexDictParser
|
||||
from langchain.output_parsers.retry import RetryOutputParser, RetryWithErrorOutputParser
|
||||
from langchain.output_parsers.structured import ResponseSchema, StructuredOutputParser
|
||||
|
||||
__all__ = [
|
||||
@@ -14,9 +15,11 @@ __all__ = [
|
||||
"RegexDictParser",
|
||||
"ListOutputParser",
|
||||
"CommaSeparatedListOutputParser",
|
||||
"BaseOutputParser",
|
||||
"StructuredOutputParser",
|
||||
"ResponseSchema",
|
||||
"GuardrailsOutputParser",
|
||||
"PydanticOutputParser",
|
||||
"RetryOutputParser",
|
||||
"RetryWithErrorOutputParser",
|
||||
"FixOutputParser",
|
||||
]
|
||||
|
||||
@@ -1,28 +0,0 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Any, Dict
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
class BaseOutputParser(BaseModel, ABC):
|
||||
"""Class to parse the output of an LLM call."""
|
||||
|
||||
@abstractmethod
|
||||
def parse(self, text: str) -> Any:
|
||||
"""Parse the output of an LLM call."""
|
||||
|
||||
def get_format_instructions(self) -> str:
|
||||
raise NotImplementedError
|
||||
|
||||
@property
|
||||
def _type(self) -> str:
|
||||
"""Return the type key."""
|
||||
raise NotImplementedError
|
||||
|
||||
def dict(self, **kwargs: Any) -> Dict:
|
||||
"""Return dictionary representation of output parser."""
|
||||
output_parser_dict = super().dict()
|
||||
output_parser_dict["_type"] = self._type
|
||||
return output_parser_dict
|
||||
41
langchain/output_parsers/fix.py
Normal file
41
langchain/output_parsers/fix.py
Normal file
@@ -0,0 +1,41 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any
|
||||
|
||||
from langchain.chains.llm import LLMChain
|
||||
from langchain.output_parsers.prompts import NAIVE_FIX_PROMPT
|
||||
from langchain.prompts.base import BasePromptTemplate
|
||||
from langchain.schema import BaseLanguageModel, BaseOutputParser, OutputParserException
|
||||
|
||||
|
||||
class FixOutputParser(BaseOutputParser):
|
||||
"""Wraps a parser and tries to fix parsing errors."""
|
||||
|
||||
parser: BaseOutputParser
|
||||
retry_chain: LLMChain
|
||||
|
||||
@classmethod
|
||||
def from_llm(
|
||||
cls,
|
||||
llm: BaseLanguageModel,
|
||||
parser: BaseOutputParser,
|
||||
prompt: BasePromptTemplate = NAIVE_FIX_PROMPT,
|
||||
) -> FixOutputParser:
|
||||
chain = LLMChain(llm=llm, prompt=prompt)
|
||||
return cls(parser=parser, retry_chain=chain)
|
||||
|
||||
def parse(self, completion: str) -> Any:
|
||||
try:
|
||||
parsed_completion = self.parser.parse(completion)
|
||||
except OutputParserException as e:
|
||||
new_completion = self.retry_chain.run(
|
||||
instructions=self.parser.get_format_instructions(),
|
||||
completion=completion,
|
||||
error=repr(e),
|
||||
)
|
||||
parsed_completion = self.parser.parse(new_completion)
|
||||
|
||||
return parsed_completion
|
||||
|
||||
def get_format_instructions(self) -> str:
|
||||
return self.parser.get_format_instructions()
|
||||
@@ -8,7 +8,10 @@ STRUCTURED_FORMAT_INSTRUCTIONS = """The output should be a markdown code snippet
|
||||
}}
|
||||
```"""
|
||||
|
||||
PYDANTIC_FORMAT_INSTRUCTIONS = """The output should be formatted as a JSON instance that conforms to the JSON schema below. For example, the object {{"foo": ["bar", "baz"]}} conforms to the schema {{"foo": {{"description": "a list of strings field", "type": "string"}}}}.
|
||||
PYDANTIC_FORMAT_INSTRUCTIONS = """The output should be formatted as a JSON instance that conforms to the JSON schema below.
|
||||
|
||||
As an example, for the schema {{"properties": {{"foo": {{"title": "Foo", "description": "a list of strings", "type": "array", "items": {{"type": "string"}}}}}}, "required": ["foo"]}}}}
|
||||
the object {{"foo": ["bar", "baz"]}} is a well-formatted instance of the schema. The object {{"properties": {{"foo": ["bar", "baz"]}}}} is not well-formatted.
|
||||
|
||||
Here is the output schema:
|
||||
```
|
||||
|
||||
@@ -3,7 +3,7 @@ from __future__ import annotations
|
||||
from abc import abstractmethod
|
||||
from typing import List
|
||||
|
||||
from langchain.output_parsers.base import BaseOutputParser
|
||||
from langchain.schema import BaseOutputParser
|
||||
|
||||
|
||||
class ListOutputParser(BaseOutputParser):
|
||||
|
||||
22
langchain/output_parsers/prompts.py
Normal file
22
langchain/output_parsers/prompts.py
Normal file
@@ -0,0 +1,22 @@
|
||||
# flake8: noqa
|
||||
from langchain.prompts.prompt import PromptTemplate
|
||||
|
||||
NAIVE_FIX = """Instructions:
|
||||
--------------
|
||||
{instructions}
|
||||
--------------
|
||||
Completion:
|
||||
--------------
|
||||
{completion}
|
||||
--------------
|
||||
|
||||
Above, the Completion did not satisfy the constraints given in the Instructions.
|
||||
Error:
|
||||
--------------
|
||||
{error}
|
||||
--------------
|
||||
|
||||
Please try again. Please only respond with an answer that satisfies the constraints laid out in the Instructions:"""
|
||||
|
||||
|
||||
NAIVE_FIX_PROMPT = PromptTemplate.from_template(NAIVE_FIX)
|
||||
@@ -4,8 +4,8 @@ from typing import Any
|
||||
|
||||
from pydantic import BaseModel, ValidationError
|
||||
|
||||
from langchain.output_parsers.base import BaseOutputParser
|
||||
from langchain.output_parsers.format_instructions import PYDANTIC_FORMAT_INSTRUCTIONS
|
||||
from langchain.schema import BaseOutputParser, OutputParserException
|
||||
|
||||
|
||||
class PydanticOutputParser(BaseOutputParser):
|
||||
@@ -14,7 +14,9 @@ class PydanticOutputParser(BaseOutputParser):
|
||||
def parse(self, text: str) -> BaseModel:
|
||||
try:
|
||||
# Greedy search for 1st json candidate.
|
||||
match = re.search("\{.*\}", text.strip())
|
||||
match = re.search(
|
||||
"\{.*\}", text.strip(), re.MULTILINE | re.IGNORECASE | re.DOTALL
|
||||
)
|
||||
json_str = ""
|
||||
if match:
|
||||
json_str = match.group()
|
||||
@@ -24,16 +26,17 @@ class PydanticOutputParser(BaseOutputParser):
|
||||
except (json.JSONDecodeError, ValidationError) as e:
|
||||
name = self.pydantic_object.__name__
|
||||
msg = f"Failed to parse {name} from completion {text}. Got: {e}"
|
||||
raise ValueError(msg)
|
||||
raise OutputParserException(msg)
|
||||
|
||||
def get_format_instructions(self) -> str:
|
||||
schema = self.pydantic_object.schema()
|
||||
|
||||
# Remove extraneous fields.
|
||||
reduced_schema = {
|
||||
prop: {"description": data["description"], "type": data["type"]}
|
||||
for prop, data in schema["properties"].items()
|
||||
}
|
||||
reduced_schema = schema
|
||||
if "title" in reduced_schema:
|
||||
del reduced_schema["title"]
|
||||
if "type" in reduced_schema:
|
||||
del reduced_schema["type"]
|
||||
# Ensure json in context is well-formed with double quotes.
|
||||
schema = json.dumps(reduced_schema)
|
||||
|
||||
|
||||
@@ -2,7 +2,7 @@ from __future__ import annotations
|
||||
|
||||
from typing import Any, Dict
|
||||
|
||||
from langchain.output_parsers.base import BaseOutputParser
|
||||
from langchain.schema import BaseOutputParser
|
||||
|
||||
|
||||
class GuardrailsOutputParser(BaseOutputParser):
|
||||
|
||||
@@ -5,7 +5,7 @@ from typing import Dict, List, Optional
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from langchain.output_parsers.base import BaseOutputParser
|
||||
from langchain.schema import BaseOutputParser
|
||||
|
||||
|
||||
class RegexParser(BaseOutputParser, BaseModel):
|
||||
|
||||
@@ -5,7 +5,7 @@ from typing import Dict, Optional
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from langchain.output_parsers.base import BaseOutputParser
|
||||
from langchain.schema import BaseOutputParser
|
||||
|
||||
|
||||
class RegexDictParser(BaseOutputParser, BaseModel):
|
||||
|
||||
107
langchain/output_parsers/retry.py
Normal file
107
langchain/output_parsers/retry.py
Normal file
@@ -0,0 +1,107 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any
|
||||
|
||||
from langchain.chains.llm import LLMChain
|
||||
from langchain.prompts.base import BasePromptTemplate
|
||||
from langchain.prompts.prompt import PromptTemplate
|
||||
from langchain.schema import (
|
||||
BaseLanguageModel,
|
||||
BaseOutputParser,
|
||||
OutputParserException,
|
||||
PromptValue,
|
||||
)
|
||||
|
||||
NAIVE_COMPLETION_RETRY = """Prompt:
|
||||
{prompt}
|
||||
Completion:
|
||||
{completion}
|
||||
|
||||
Above, the Completion did not satisfy the constraints given in the Prompt.
|
||||
Please try again:"""
|
||||
|
||||
NAIVE_COMPLETION_RETRY_WITH_ERROR = """Prompt:
|
||||
{prompt}
|
||||
Completion:
|
||||
{completion}
|
||||
|
||||
Above, the Completion did not satisfy the constraints given in the Prompt.
|
||||
Details: {error}
|
||||
Please try again:"""
|
||||
|
||||
NAIVE_RETRY_PROMPT = PromptTemplate.from_template(NAIVE_COMPLETION_RETRY)
|
||||
NAIVE_RETRY_WITH_ERROR_PROMPT = PromptTemplate.from_template(
|
||||
NAIVE_COMPLETION_RETRY_WITH_ERROR
|
||||
)
|
||||
|
||||
|
||||
class RetryOutputParser(BaseOutputParser):
|
||||
"""Wraps a parser and tries to fix parsing errors."""
|
||||
|
||||
parser: BaseOutputParser
|
||||
retry_chain: LLMChain
|
||||
|
||||
@classmethod
|
||||
def from_llm(
|
||||
cls,
|
||||
llm: BaseLanguageModel,
|
||||
parser: BaseOutputParser,
|
||||
prompt: BasePromptTemplate = NAIVE_RETRY_PROMPT,
|
||||
) -> RetryOutputParser:
|
||||
chain = LLMChain(llm=llm, prompt=prompt)
|
||||
return cls(parser=parser, retry_chain=chain)
|
||||
|
||||
def parse_with_prompt(self, completion: str, prompt_value: PromptValue) -> Any:
|
||||
try:
|
||||
parsed_completion = self.parser.parse(completion)
|
||||
except OutputParserException:
|
||||
new_completion = self.retry_chain.run(
|
||||
prompt=prompt_value.to_string(), completion=completion
|
||||
)
|
||||
parsed_completion = self.parser.parse(new_completion)
|
||||
|
||||
return parsed_completion
|
||||
|
||||
def parse(self, completion: str) -> Any:
|
||||
raise NotImplementedError(
|
||||
"This OutputParser can only be called by the `parse_with_prompt` method."
|
||||
)
|
||||
|
||||
def get_format_instructions(self) -> str:
|
||||
return self.parser.get_format_instructions()
|
||||
|
||||
|
||||
class RetryWithErrorOutputParser(BaseOutputParser):
|
||||
"""Wraps a parser and tries to fix parsing errors."""
|
||||
|
||||
parser: BaseOutputParser
|
||||
retry_chain: LLMChain
|
||||
|
||||
@classmethod
|
||||
def from_llm(
|
||||
cls,
|
||||
llm: BaseLanguageModel,
|
||||
parser: BaseOutputParser,
|
||||
prompt: BasePromptTemplate = NAIVE_RETRY_WITH_ERROR_PROMPT,
|
||||
) -> RetryWithErrorOutputParser:
|
||||
chain = LLMChain(llm=llm, prompt=prompt)
|
||||
return cls(parser=parser, retry_chain=chain)
|
||||
|
||||
def parse_with_prompt(self, completion: str, prompt_value: PromptValue) -> Any:
|
||||
try:
|
||||
parsed_completion = self.parser.parse(completion)
|
||||
except OutputParserException as e:
|
||||
new_completion = self.retry_chain.run(
|
||||
prompt=prompt_value.to_string(), completion=completion, error=repr(e)
|
||||
)
|
||||
parsed_completion = self.parser.parse(new_completion)
|
||||
|
||||
return parsed_completion
|
||||
|
||||
def parse(self, completion: str) -> Any:
|
||||
raise NotImplementedError(
|
||||
"This OutputParser can only be called by the `parse_with_prompt` method."
|
||||
)
|
||||
|
||||
def get_format_instructions(self) -> str:
|
||||
return self.parser.get_format_instructions()
|
||||
@@ -5,8 +5,8 @@ from typing import List
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from langchain.output_parsers.base import BaseOutputParser
|
||||
from langchain.output_parsers.format_instructions import STRUCTURED_FORMAT_INSTRUCTIONS
|
||||
from langchain.schema import BaseOutputParser, OutputParserException
|
||||
|
||||
line_template = '\t"{name}": {type} // {description}'
|
||||
|
||||
@@ -42,7 +42,7 @@ class StructuredOutputParser(BaseOutputParser):
|
||||
json_obj = json.loads(json_string)
|
||||
for schema in self.response_schemas:
|
||||
if schema.name not in json_obj:
|
||||
raise ValueError(
|
||||
raise OutputParserException(
|
||||
f"Got invalid return object. Expected key `{schema.name}` "
|
||||
f"to be present, but got {json_obj}"
|
||||
)
|
||||
|
||||
@@ -10,13 +10,7 @@ import yaml
|
||||
from pydantic import BaseModel, Extra, Field, root_validator
|
||||
|
||||
from langchain.formatting import formatter
|
||||
from langchain.output_parsers.base import BaseOutputParser
|
||||
from langchain.output_parsers.list import ( # noqa: F401
|
||||
CommaSeparatedListOutputParser,
|
||||
ListOutputParser,
|
||||
)
|
||||
from langchain.output_parsers.regex import RegexParser # noqa: F401
|
||||
from langchain.schema import BaseMessage, HumanMessage, PromptValue
|
||||
from langchain.schema import BaseMessage, BaseOutputParser, HumanMessage, PromptValue
|
||||
|
||||
|
||||
def jinja2_formatter(template: str, **kwargs: Any) -> str:
|
||||
|
||||
@@ -259,3 +259,32 @@ class BaseMemory(BaseModel, ABC):
|
||||
|
||||
|
||||
Memory = BaseMemory
|
||||
|
||||
|
||||
class BaseOutputParser(BaseModel, ABC):
|
||||
"""Class to parse the output of an LLM call."""
|
||||
|
||||
@abstractmethod
|
||||
def parse(self, text: str) -> Any:
|
||||
"""Parse the output of an LLM call."""
|
||||
|
||||
def parse_with_prompt(self, completion: str, prompt: PromptValue) -> Any:
|
||||
return self.parse(completion)
|
||||
|
||||
def get_format_instructions(self) -> str:
|
||||
raise NotImplementedError
|
||||
|
||||
@property
|
||||
def _type(self) -> str:
|
||||
"""Return the type key."""
|
||||
raise NotImplementedError
|
||||
|
||||
def dict(self, **kwargs: Any) -> Dict:
|
||||
"""Return dictionary representation of output parser."""
|
||||
output_parser_dict = super().dict()
|
||||
output_parser_dict["_type"] = self._type
|
||||
return output_parser_dict
|
||||
|
||||
|
||||
class OutputParserException(Exception):
|
||||
pass
|
||||
|
||||
@@ -7,8 +7,8 @@ import pytest
|
||||
|
||||
from langchain.chains.llm import LLMChain
|
||||
from langchain.chains.loading import load_chain
|
||||
from langchain.output_parsers.base import BaseOutputParser
|
||||
from langchain.prompts.prompt import PromptTemplate
|
||||
from langchain.schema import BaseOutputParser
|
||||
from tests.unit_tests.llms.fake_llm import FakeLLM
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user