try to make guarded/retriable output parser an instance of parser

This commit is contained in:
jerwelborn
2023-03-20 13:58:12 -07:00
parent 325825d55f
commit a0cde05839
5 changed files with 35 additions and 52 deletions

View File

@@ -3,7 +3,7 @@
{
"cell_type": "code",
"execution_count": 1,
"id": "10d915e4",
"id": "4e75cce5",
"metadata": {},
"outputs": [],
"source": [
@@ -13,7 +13,7 @@
},
{
"cell_type": "markdown",
"id": "c09a2430",
"id": "a930f49b",
"metadata": {},
"source": [
"# Guarded Output Parsers\n",
@@ -22,15 +22,15 @@
"\n",
"Unfortunately, small models to date don't have capacity to generate well-formed, schema-adherent json while large models still sometimes fail.\n",
"\n",
"In this notebook, we showcase a `GuardedOutputParser` which can be dropped in for an `OutputParser` in an `LLMChain`. It will catch errors at parsing time and try resolve them, initially by re-invoking an LLM.\n",
"In this notebook, we showcase a \"Guarded\" `OutputParser` which can be dropped in for an `OutputParser` in an `LLMChain`. It will catch errors at parsing time and try to resolve them, initially by re-invoking an LLM.\n",
"\n",
"Below are some examples:"
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "3db24806",
"execution_count": 1,
"id": "8e6bd7ef",
"metadata": {},
"outputs": [],
"source": [
@@ -38,7 +38,7 @@
"from typing import List\n",
"\n",
"from langchain.chains import LLMChain\n",
"from langchain.guardrails.parsing import GuardedOutputParser\n",
"from langchain.guardrails.parsing import RetriableOutputParser\n",
"from langchain.llms import OpenAI\n",
"from langchain.output_parsers import PydanticOutputParser, OutputParserException\n",
"from langchain.prompts import PromptTemplate"
@@ -46,7 +46,7 @@
},
{
"cell_type": "markdown",
"id": "8d40796c",
"id": "8bc39653",
"metadata": {},
"source": [
"## 1st example"
@@ -54,8 +54,8 @@
},
{
"cell_type": "code",
"execution_count": 3,
"id": "eec61f2d",
"execution_count": 2,
"id": "01337598",
"metadata": {},
"outputs": [],
"source": [
@@ -69,8 +69,8 @@
},
{
"cell_type": "code",
"execution_count": 4,
"id": "42b8f744",
"execution_count": 3,
"id": "618dca6a",
"metadata": {},
"outputs": [],
"source": [
@@ -89,8 +89,8 @@
},
{
"cell_type": "code",
"execution_count": 6,
"id": "67a1d356",
"execution_count": 4,
"id": "6ccff983",
"metadata": {},
"outputs": [
{
@@ -112,7 +112,7 @@
"\u001b[0m\n",
"Dang!\n",
"Failed to parse FloatArray from completion \n",
"Fiboacci is a sequence of numbers defined by the Fibonacci sequence: 0, 1, 1, 2, 3, 5, 8, 13, 21.. Got: Expecting value: line 1 column 1 (char 0)\n"
"Fiboacci is a sequence of numbers that are the sum of the previous two numbers in the sequence.. Got: Expecting value: line 1 column 1 (char 0)\n"
]
}
],
@@ -132,8 +132,8 @@
},
{
"cell_type": "code",
"execution_count": 7,
"id": "c568ebbb",
"execution_count": 5,
"id": "3c68ad59",
"metadata": {},
"outputs": [
{
@@ -160,17 +160,17 @@
{
"data": {
"text/plain": [
"FloatArray(values=[0.0, 1.0, 1.0, 2.0, 3.0, 5.0, 8.0, 13.0, 21.0, 34.0, 55.0, 89.0, 144.0])"
"FloatArray(values=[1.0, 2.0, 3.0, 5.0, 8.0, 13.0, 21.0, 34.0, 55.0, 89.0])"
]
},
"execution_count": 7,
"execution_count": 5,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# We can replace the parser with a guarded parser that tries to fix errors with a bigger model.\n",
"guarded_parser = GuardedOutputParser(\n",
"guarded_parser = RetriableOutputParser(\n",
" parser=parser, retry_llm=OpenAI(model_name=\"text-davinci-003\"))\n",
"prompt.output_parser = guarded_parser\n",
"\n",
@@ -179,27 +179,11 @@
},
{
"cell_type": "markdown",
"id": "aa0e7b3c",
"id": "f5fc247e",
"metadata": {},
"source": [
"This example is demonstrative though. If your goal is to generate data structures, probably you'll want to start with a large enough model."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "3edbc886",
"metadata": {},
"outputs": [],
"source": []
},
{
"cell_type": "code",
"execution_count": null,
"id": "d83cfb61",
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {

View File

@@ -15,8 +15,6 @@ from langchain.schema import (
LLMResult,
PromptValue,
)
from langchain.output_parsers import BaseOutputParser
from langchain.guardrails.parsing import GuardedOutputParser
class LLMChain(Chain, BaseModel):
@@ -140,13 +138,7 @@ class LLMChain(Chain, BaseModel):
"""Get the final output from a list of generations for a prompt."""
completion = generations[0].text
if self.prompt.output_parser:
parser = self.prompt.output_parser
if isinstance(parser, BaseOutputParser):
completion = parser.parse(completion)
elif isinstance(parser, GuardedOutputParser):
# TODO: not ideal to hide retry calls from user. can we expose it in colored log?
completion = parser.parse(prompt_value, completion)
completion = self.prompt.output_parser.parse_with_prompt(completion, prompt_value)
return completion
def create_outputs(

View File

@@ -1,4 +1,3 @@
from pydantic import BaseModel
from typing import Any
from langchain.guardrails.retry import naive_retry
@@ -6,13 +5,13 @@ from langchain.output_parsers import BaseOutputParser, OutputParserException
from langchain.schema import BaseLanguageModel, PromptValue
class GuardedOutputParser(BaseModel):
class RetriableOutputParser(BaseOutputParser):
"""Wraps a parser and tries to fix parsing errors."""
parser: BaseOutputParser
retry_llm: BaseLanguageModel
def parse(self, prompt_value: PromptValue, completion: str) -> Any:
def parse_with_prompt(self, completion: str, prompt_value: PromptValue) -> Any:
try:
parsed_completion = self.parser.parse(completion)
except OutputParserException as e:
@@ -22,4 +21,10 @@ class GuardedOutputParser(BaseModel):
)
parsed_completion = self.parser.parse(new_completion)
return parsed_completion
return parsed_completion
def parse(self, completion: str):
raise NotImplementedError
def get_format_instructions(self) -> str:
return self.parser.get_format_instructions()

View File

@@ -1,10 +1,11 @@
from __future__ import annotations
from abc import ABC, abstractmethod
from typing import Any, Dict
from typing import Any, Dict, Optional
from pydantic import BaseModel
from langchain.schema import PromptValue
class BaseOutputParser(BaseModel, ABC):
"""Class to parse the output of an LLM call."""
@@ -13,6 +14,9 @@ class BaseOutputParser(BaseModel, ABC):
def parse(self, text: str) -> Any:
"""Parse the output of an LLM call."""
def parse_with_prompt(self, completion: str, prompt: Optional[PromptValue] = None) -> Any:
return self.parse(completion)
def get_format_instructions(self) -> str:
raise NotImplementedError

View File

@@ -6,8 +6,6 @@ from typing import Any, Dict, List, NamedTuple, Optional
from pydantic import BaseModel, Extra, Field, root_validator
from langchain.output_parsers import BaseOutputParser
def get_buffer_string(
messages: List[BaseMessage], human_prefix: str = "Human", ai_prefix: str = "AI"