mirror of
https://github.com/hwchase17/langchain.git
synced 2026-01-24 05:50:18 +00:00
try make guarded/retriable output parser an instance of parser
This commit is contained in:
@@ -3,7 +3,7 @@
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 1,
|
||||
"id": "10d915e4",
|
||||
"id": "4e75cce5",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
@@ -13,7 +13,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "c09a2430",
|
||||
"id": "a930f49b",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"# Guarded Output Parsers\n",
|
||||
@@ -22,15 +22,15 @@
|
||||
"\n",
|
||||
"Unfortunately, small models to date don't have the capacity to generate well-formed, schema-adherent JSON, while even large models still sometimes fail.\n",
|
||||
"\n",
|
||||
"In this notebook, we showcase a `GuardedOutputParser` which can be dropped in for an `OutputParser` in an `LLMChain`. It will catch errors at parsing time and try to resolve them, initially by re-invoking an LLM.\n",
|
||||
"In this notebook, we showcase a \"Guarded\" `OutputParser` which can be dropped in for an `OutputParser` in an `LLMChain`. It will catch errors at parsing time and try to resolve them, initially by re-invoking an LLM.\n",
|
||||
"\n",
|
||||
"Below are some examples:"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 2,
|
||||
"id": "3db24806",
|
||||
"execution_count": 1,
|
||||
"id": "8e6bd7ef",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
@@ -38,7 +38,7 @@
|
||||
"from typing import List\n",
|
||||
"\n",
|
||||
"from langchain.chains import LLMChain\n",
|
||||
"from langchain.guardrails.parsing import GuardedOutputParser\n",
|
||||
"from langchain.guardrails.parsing import RetriableOutputParser\n",
|
||||
"from langchain.llms import OpenAI\n",
|
||||
"from langchain.output_parsers import PydanticOutputParser, OutputParserException\n",
|
||||
"from langchain.prompts import PromptTemplate"
|
||||
@@ -46,7 +46,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "8d40796c",
|
||||
"id": "8bc39653",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## 1st example"
|
||||
@@ -54,8 +54,8 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 3,
|
||||
"id": "eec61f2d",
|
||||
"execution_count": 2,
|
||||
"id": "01337598",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
@@ -69,8 +69,8 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 4,
|
||||
"id": "42b8f744",
|
||||
"execution_count": 3,
|
||||
"id": "618dca6a",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
@@ -89,8 +89,8 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 6,
|
||||
"id": "67a1d356",
|
||||
"execution_count": 4,
|
||||
"id": "6ccff983",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
@@ -112,7 +112,7 @@
|
||||
"\u001b[0m\n",
|
||||
"Dang!\n",
|
||||
"Failed to parse FloatArray from completion \n",
|
||||
"Fiboacci is a sequence of numbers defined by the Fibonacci sequence: 0, 1, 1, 2, 3, 5, 8, 13, 21.. Got: Expecting value: line 1 column 1 (char 0)\n"
|
||||
"Fiboacci is a sequence of numbers that are the sum of the previous two numbers in the sequence.. Got: Expecting value: line 1 column 1 (char 0)\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
@@ -132,8 +132,8 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 7,
|
||||
"id": "c568ebbb",
|
||||
"execution_count": 5,
|
||||
"id": "3c68ad59",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
@@ -160,17 +160,17 @@
|
||||
{
|
||||
"data": {
|
||||
"text/plain": [
|
||||
"FloatArray(values=[0.0, 1.0, 1.0, 2.0, 3.0, 5.0, 8.0, 13.0, 21.0, 34.0, 55.0, 89.0, 144.0])"
|
||||
"FloatArray(values=[1.0, 2.0, 3.0, 5.0, 8.0, 13.0, 21.0, 34.0, 55.0, 89.0])"
|
||||
]
|
||||
},
|
||||
"execution_count": 7,
|
||||
"execution_count": 5,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"# We can replace the parser with a guarded parser that tries to fix errors with a bigger model.\n",
|
||||
"guarded_parser = GuardedOutputParser(\n",
|
||||
"guarded_parser = RetriableOutputParser(\n",
|
||||
" parser=parser, retry_llm=OpenAI(model_name=\"text-davinci-003\"))\n",
|
||||
"prompt.output_parser = guarded_parser\n",
|
||||
"\n",
|
||||
@@ -179,27 +179,11 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "aa0e7b3c",
|
||||
"id": "f5fc247e",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"This example is only demonstrative, though. If your goal is to generate data structures, you'll probably want to start with a large enough model."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "3edbc886",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": []
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "d83cfb61",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": []
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
|
||||
@@ -15,8 +15,6 @@ from langchain.schema import (
|
||||
LLMResult,
|
||||
PromptValue,
|
||||
)
|
||||
from langchain.output_parsers import BaseOutputParser
|
||||
from langchain.guardrails.parsing import GuardedOutputParser
|
||||
|
||||
|
||||
class LLMChain(Chain, BaseModel):
|
||||
@@ -140,13 +138,7 @@ class LLMChain(Chain, BaseModel):
|
||||
"""Get the final output from a list of generations for a prompt."""
|
||||
completion = generations[0].text
|
||||
if self.prompt.output_parser:
|
||||
parser = self.prompt.output_parser
|
||||
if isinstance(parser, BaseOutputParser):
|
||||
completion = parser.parse(completion)
|
||||
elif isinstance(parser, GuardedOutputParser):
|
||||
# TODO: not ideal to hide retry calls from user. can we expose it in colored log?
|
||||
completion = parser.parse(prompt_value, completion)
|
||||
|
||||
completion = self.prompt.output_parser.parse_with_prompt(completion, prompt_value)
|
||||
return completion
|
||||
|
||||
def create_outputs(
|
||||
|
||||
@@ -1,4 +1,3 @@
|
||||
from pydantic import BaseModel
|
||||
from typing import Any
|
||||
|
||||
from langchain.guardrails.retry import naive_retry
|
||||
@@ -6,13 +5,13 @@ from langchain.output_parsers import BaseOutputParser, OutputParserException
|
||||
from langchain.schema import BaseLanguageModel, PromptValue
|
||||
|
||||
|
||||
class GuardedOutputParser(BaseModel):
|
||||
class RetriableOutputParser(BaseOutputParser):
|
||||
"""Wraps a parser and tries to fix parsing errors."""
|
||||
|
||||
parser: BaseOutputParser
|
||||
retry_llm: BaseLanguageModel
|
||||
|
||||
def parse(self, prompt_value: PromptValue, completion: str) -> Any:
|
||||
def parse_with_prompt(self, completion: str, prompt_value: PromptValue) -> Any:
|
||||
try:
|
||||
parsed_completion = self.parser.parse(completion)
|
||||
except OutputParserException as e:
|
||||
@@ -22,4 +21,10 @@ class GuardedOutputParser(BaseModel):
|
||||
)
|
||||
parsed_completion = self.parser.parse(new_completion)
|
||||
|
||||
return parsed_completion
|
||||
return parsed_completion
|
||||
|
||||
def parse(self, completion: str):
|
||||
raise NotImplementedError
|
||||
|
||||
def get_format_instructions(self) -> str:
|
||||
return self.parser.get_format_instructions()
|
||||
|
||||
@@ -1,10 +1,11 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Any, Dict
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from langchain.schema import PromptValue
|
||||
|
||||
class BaseOutputParser(BaseModel, ABC):
|
||||
"""Class to parse the output of an LLM call."""
|
||||
@@ -13,6 +14,9 @@ class BaseOutputParser(BaseModel, ABC):
|
||||
def parse(self, text: str) -> Any:
|
||||
"""Parse the output of an LLM call."""
|
||||
|
||||
def parse_with_prompt(self, completion: str, prompt: Optional[PromptValue] = None) -> Any:
|
||||
return self.parse(completion)
|
||||
|
||||
def get_format_instructions(self) -> str:
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
@@ -6,8 +6,6 @@ from typing import Any, Dict, List, NamedTuple, Optional
|
||||
|
||||
from pydantic import BaseModel, Extra, Field, root_validator
|
||||
|
||||
from langchain.output_parsers import BaseOutputParser
|
||||
|
||||
|
||||
def get_buffer_string(
|
||||
messages: List[BaseMessage], human_prefix: str = "Human", ai_prefix: str = "AI"
|
||||
|
||||
Reference in New Issue
Block a user