mirror of
https://github.com/hwchase17/langchain.git
synced 2025-06-05 06:33:20 +00:00
Supported RetryOutputParser & RetryWithErrorOutputParser max_retries (#11903)
Description: Supported RetryOutputParser & RetryWithErrorOutputParser max_retries - max_retries: Maximum number of retries to parser. Issue: None Dependencies: None Tag maintainer: @baskaryan Twitter handle:
This commit is contained in:
parent
008c7df80d
commit
a6b483dcbc
@ -17,9 +17,12 @@ class OutputFixingParser(BaseOutputParser[T]):
|
||||
return True
|
||||
|
||||
parser: BaseOutputParser[T]
|
||||
"""The parser to use to parse the output."""
|
||||
# Should be an LLMChain but we want to avoid top-level imports from langchain.chains
|
||||
retry_chain: Any
|
||||
"""The LLMChain to use to retry the completion."""
|
||||
max_retries: int = 1
|
||||
"""The maximum number of times to retry the parse."""
|
||||
|
||||
@classmethod
|
||||
def from_llm(
|
||||
@ -35,7 +38,7 @@ class OutputFixingParser(BaseOutputParser[T]):
|
||||
llm: llm to use for fixing
|
||||
parser: parser to use for parsing
|
||||
prompt: prompt to use for fixing
|
||||
max_retries: Maximum number of retries to parser.
|
||||
max_retries: Maximum number of retries to parse.
|
||||
|
||||
Returns:
|
||||
OutputFixingParser
|
||||
|
@ -48,6 +48,8 @@ class RetryOutputParser(BaseOutputParser[T]):
|
||||
# Should be an LLMChain but we want to avoid top-level imports from langchain.chains
|
||||
retry_chain: Any
|
||||
"""The LLMChain to use to retry the completion."""
|
||||
max_retries: int = 1
|
||||
"""The maximum number of times to retry the parse."""
|
||||
|
||||
@classmethod
|
||||
def from_llm(
|
||||
@ -55,11 +57,23 @@ class RetryOutputParser(BaseOutputParser[T]):
|
||||
llm: BaseLanguageModel,
|
||||
parser: BaseOutputParser[T],
|
||||
prompt: BasePromptTemplate = NAIVE_RETRY_PROMPT,
|
||||
max_retries: int = 1,
|
||||
) -> RetryOutputParser[T]:
|
||||
"""Create an OutputFixingParser from a language model and a parser.
|
||||
|
||||
Args:
|
||||
llm: llm to use for fixing
|
||||
parser: parser to use for parsing
|
||||
prompt: prompt to use for fixing
|
||||
max_retries: Maximum number of retries to parse.
|
||||
|
||||
Returns:
|
||||
RetryOutputParser
|
||||
"""
|
||||
from langchain.chains.llm import LLMChain
|
||||
|
||||
chain = LLMChain(llm=llm, prompt=prompt)
|
||||
return cls(parser=parser, retry_chain=chain)
|
||||
return cls(parser=parser, retry_chain=chain, max_retries=max_retries)
|
||||
|
||||
def parse_with_prompt(self, completion: str, prompt_value: PromptValue) -> T:
|
||||
"""Parse the output of an LLM call using a wrapped parser.
|
||||
@ -71,15 +85,21 @@ class RetryOutputParser(BaseOutputParser[T]):
|
||||
Returns:
|
||||
The parsed completion.
|
||||
"""
|
||||
try:
|
||||
parsed_completion = self.parser.parse(completion)
|
||||
except OutputParserException:
|
||||
new_completion = self.retry_chain.run(
|
||||
prompt=prompt_value.to_string(), completion=completion
|
||||
)
|
||||
parsed_completion = self.parser.parse(new_completion)
|
||||
retries = 0
|
||||
|
||||
return parsed_completion
|
||||
while retries <= self.max_retries:
|
||||
try:
|
||||
return self.parser.parse(completion)
|
||||
except OutputParserException as e:
|
||||
if retries == self.max_retries:
|
||||
raise e
|
||||
else:
|
||||
retries += 1
|
||||
completion = self.retry_chain.run(
|
||||
prompt=prompt_value.to_string(), completion=completion
|
||||
)
|
||||
|
||||
raise OutputParserException("Failed to parse")
|
||||
|
||||
async def aparse_with_prompt(self, completion: str, prompt_value: PromptValue) -> T:
|
||||
"""Parse the output of an LLM call using a wrapped parser.
|
||||
@ -91,15 +111,21 @@ class RetryOutputParser(BaseOutputParser[T]):
|
||||
Returns:
|
||||
The parsed completion.
|
||||
"""
|
||||
try:
|
||||
parsed_completion = self.parser.parse(completion)
|
||||
except OutputParserException:
|
||||
new_completion = await self.retry_chain.arun(
|
||||
prompt=prompt_value.to_string(), completion=completion
|
||||
)
|
||||
parsed_completion = self.parser.parse(new_completion)
|
||||
retries = 0
|
||||
|
||||
return parsed_completion
|
||||
while retries <= self.max_retries:
|
||||
try:
|
||||
return await self.parser.aparse(completion)
|
||||
except OutputParserException as e:
|
||||
if retries == self.max_retries:
|
||||
raise e
|
||||
else:
|
||||
retries += 1
|
||||
completion = await self.retry_chain.arun(
|
||||
prompt=prompt_value.to_string(), completion=completion
|
||||
)
|
||||
|
||||
raise OutputParserException("Failed to parse")
|
||||
|
||||
def parse(self, completion: str) -> T:
|
||||
raise NotImplementedError(
|
||||
@ -125,8 +151,12 @@ class RetryWithErrorOutputParser(BaseOutputParser[T]):
|
||||
"""
|
||||
|
||||
parser: BaseOutputParser[T]
|
||||
"""The parser to use to parse the output."""
|
||||
# Should be an LLMChain but we want to avoid top-level imports from langchain.chains
|
||||
retry_chain: Any
|
||||
"""The LLMChain to use to retry the completion."""
|
||||
max_retries: int = 1
|
||||
"""The maximum number of times to retry the parse."""
|
||||
|
||||
@classmethod
|
||||
def from_llm(
|
||||
@ -134,6 +164,7 @@ class RetryWithErrorOutputParser(BaseOutputParser[T]):
|
||||
llm: BaseLanguageModel,
|
||||
parser: BaseOutputParser[T],
|
||||
prompt: BasePromptTemplate = NAIVE_RETRY_WITH_ERROR_PROMPT,
|
||||
max_retries: int = 1,
|
||||
) -> RetryWithErrorOutputParser[T]:
|
||||
"""Create a RetryWithErrorOutputParser from an LLM.
|
||||
|
||||
@ -141,6 +172,7 @@ class RetryWithErrorOutputParser(BaseOutputParser[T]):
|
||||
llm: The LLM to use to retry the completion.
|
||||
parser: The parser to use to parse the output.
|
||||
prompt: The prompt to use to retry the completion.
|
||||
max_retries: The maximum number of times to retry the completion.
|
||||
|
||||
Returns:
|
||||
A RetryWithErrorOutputParser.
|
||||
@ -148,29 +180,45 @@ class RetryWithErrorOutputParser(BaseOutputParser[T]):
|
||||
from langchain.chains.llm import LLMChain
|
||||
|
||||
chain = LLMChain(llm=llm, prompt=prompt)
|
||||
return cls(parser=parser, retry_chain=chain)
|
||||
return cls(parser=parser, retry_chain=chain, max_retries=max_retries)
|
||||
|
||||
def parse_with_prompt(self, completion: str, prompt_value: PromptValue) -> T:
|
||||
try:
|
||||
parsed_completion = self.parser.parse(completion)
|
||||
except OutputParserException as e:
|
||||
new_completion = self.retry_chain.run(
|
||||
prompt=prompt_value.to_string(), completion=completion, error=repr(e)
|
||||
)
|
||||
parsed_completion = self.parser.parse(new_completion)
|
||||
retries = 0
|
||||
|
||||
return parsed_completion
|
||||
while retries <= self.max_retries:
|
||||
try:
|
||||
return self.parser.parse(completion)
|
||||
except OutputParserException as e:
|
||||
if retries == self.max_retries:
|
||||
raise e
|
||||
else:
|
||||
retries += 1
|
||||
completion = self.retry_chain.run(
|
||||
prompt=prompt_value.to_string(),
|
||||
completion=completion,
|
||||
error=repr(e),
|
||||
)
|
||||
|
||||
raise OutputParserException("Failed to parse")
|
||||
|
||||
async def aparse_with_prompt(self, completion: str, prompt_value: PromptValue) -> T:
|
||||
try:
|
||||
parsed_completion = self.parser.parse(completion)
|
||||
except OutputParserException as e:
|
||||
new_completion = await self.retry_chain.arun(
|
||||
prompt=prompt_value.to_string(), completion=completion, error=repr(e)
|
||||
)
|
||||
parsed_completion = self.parser.parse(new_completion)
|
||||
retries = 0
|
||||
|
||||
return parsed_completion
|
||||
while retries <= self.max_retries:
|
||||
try:
|
||||
return await self.parser.aparse(completion)
|
||||
except OutputParserException as e:
|
||||
if retries == self.max_retries:
|
||||
raise e
|
||||
else:
|
||||
retries += 1
|
||||
completion = await self.retry_chain.arun(
|
||||
prompt=prompt_value.to_string(),
|
||||
completion=completion,
|
||||
error=repr(e),
|
||||
)
|
||||
|
||||
raise OutputParserException("Failed to parse")
|
||||
|
||||
def parse(self, completion: str) -> T:
|
||||
raise NotImplementedError(
|
||||
|
Loading…
Reference in New Issue
Block a user