diff --git a/libs/langchain/langchain/output_parsers/fix.py b/libs/langchain/langchain/output_parsers/fix.py index 0b66e750a5d..b258ed3344a 100644 --- a/libs/langchain/langchain/output_parsers/fix.py +++ b/libs/langchain/langchain/output_parsers/fix.py @@ -17,9 +17,12 @@ class OutputFixingParser(BaseOutputParser[T]): return True parser: BaseOutputParser[T] + """The parser to use to parse the output.""" # Should be an LLMChain but we want to avoid top-level imports from langchain.chains retry_chain: Any + """The LLMChain to use to retry the completion.""" max_retries: int = 1 + """The maximum number of times to retry the parse.""" @classmethod def from_llm( @@ -35,7 +38,7 @@ class OutputFixingParser(BaseOutputParser[T]): llm: llm to use for fixing parser: parser to use for parsing prompt: prompt to use for fixing - max_retries: Maximum number of retries to parser. + max_retries: Maximum number of retries to parse. Returns: OutputFixingParser diff --git a/libs/langchain/langchain/output_parsers/retry.py b/libs/langchain/langchain/output_parsers/retry.py index b40a7fdc365..c78f2469b99 100644 --- a/libs/langchain/langchain/output_parsers/retry.py +++ b/libs/langchain/langchain/output_parsers/retry.py @@ -48,6 +48,8 @@ class RetryOutputParser(BaseOutputParser[T]): # Should be an LLMChain but we want to avoid top-level imports from langchain.chains retry_chain: Any """The LLMChain to use to retry the completion.""" + max_retries: int = 1 + """The maximum number of times to retry the parse.""" @classmethod def from_llm( @@ -55,11 +57,23 @@ class RetryOutputParser(BaseOutputParser[T]): llm: BaseLanguageModel, parser: BaseOutputParser[T], prompt: BasePromptTemplate = NAIVE_RETRY_PROMPT, + max_retries: int = 1, ) -> RetryOutputParser[T]: + """Create an OutputFixingParser from a language model and a parser. + + Args: + llm: llm to use for fixing + parser: parser to use for parsing + prompt: prompt to use for fixing + max_retries: Maximum number of retries to parse. + + Returns: + RetryOutputParser + """ from langchain.chains.llm import LLMChain chain = LLMChain(llm=llm, prompt=prompt) - return cls(parser=parser, retry_chain=chain) + return cls(parser=parser, retry_chain=chain, max_retries=max_retries) def parse_with_prompt(self, completion: str, prompt_value: PromptValue) -> T: """Parse the output of an LLM call using a wrapped parser. @@ -71,15 +85,21 @@ class RetryOutputParser(BaseOutputParser[T]): Returns: The parsed completion. """ - try: - parsed_completion = self.parser.parse(completion) - except OutputParserException: - new_completion = self.retry_chain.run( - prompt=prompt_value.to_string(), completion=completion - ) - parsed_completion = self.parser.parse(new_completion) + retries = 0 - return parsed_completion + while retries <= self.max_retries: + try: + return self.parser.parse(completion) + except OutputParserException as e: + if retries == self.max_retries: + raise e + else: + retries += 1 + completion = self.retry_chain.run( + prompt=prompt_value.to_string(), completion=completion + ) + + raise OutputParserException("Failed to parse") async def aparse_with_prompt(self, completion: str, prompt_value: PromptValue) -> T: """Parse the output of an LLM call using a wrapped parser. @@ -91,15 +111,21 @@ class RetryOutputParser(BaseOutputParser[T]): Returns: The parsed completion. """ - try: - parsed_completion = self.parser.parse(completion) - except OutputParserException: - new_completion = await self.retry_chain.arun( - prompt=prompt_value.to_string(), completion=completion - ) - parsed_completion = self.parser.parse(new_completion) + retries = 0 - return parsed_completion + while retries <= self.max_retries: + try: + return await self.parser.aparse(completion) + except OutputParserException as e: + if retries == self.max_retries: + raise e + else: + retries += 1 + completion = await self.retry_chain.arun( + prompt=prompt_value.to_string(), completion=completion + ) + + raise OutputParserException("Failed to parse") def parse(self, completion: str) -> T: raise NotImplementedError( @@ -125,8 +151,12 @@ class RetryWithErrorOutputParser(BaseOutputParser[T]): """ parser: BaseOutputParser[T] + """The parser to use to parse the output.""" # Should be an LLMChain but we want to avoid top-level imports from langchain.chains retry_chain: Any + """The LLMChain to use to retry the completion.""" + max_retries: int = 1 + """The maximum number of times to retry the parse.""" @classmethod def from_llm( @@ -134,6 +164,7 @@ class RetryWithErrorOutputParser(BaseOutputParser[T]): llm: BaseLanguageModel, parser: BaseOutputParser[T], prompt: BasePromptTemplate = NAIVE_RETRY_WITH_ERROR_PROMPT, + max_retries: int = 1, ) -> RetryWithErrorOutputParser[T]: """Create a RetryWithErrorOutputParser from an LLM. @@ -141,6 +172,7 @@ class RetryWithErrorOutputParser(BaseOutputParser[T]): llm: The LLM to use to retry the completion. parser: The parser to use to parse the output. prompt: The prompt to use to retry the completion. + max_retries: The maximum number of times to retry the completion. Returns: A RetryWithErrorOutputParser. @@ -148,29 +180,45 @@ class RetryWithErrorOutputParser(BaseOutputParser[T]): from langchain.chains.llm import LLMChain chain = LLMChain(llm=llm, prompt=prompt) - return cls(parser=parser, retry_chain=chain) + return cls(parser=parser, retry_chain=chain, max_retries=max_retries) def parse_with_prompt(self, completion: str, prompt_value: PromptValue) -> T: - try: - parsed_completion = self.parser.parse(completion) - except OutputParserException as e: - new_completion = self.retry_chain.run( - prompt=prompt_value.to_string(), completion=completion, error=repr(e) - ) - parsed_completion = self.parser.parse(new_completion) + retries = 0 - return parsed_completion + while retries <= self.max_retries: + try: + return self.parser.parse(completion) + except OutputParserException as e: + if retries == self.max_retries: + raise e + else: + retries += 1 + completion = self.retry_chain.run( + prompt=prompt_value.to_string(), + completion=completion, + error=repr(e), + ) + + raise OutputParserException("Failed to parse") async def aparse_with_prompt(self, completion: str, prompt_value: PromptValue) -> T: - try: - parsed_completion = self.parser.parse(completion) - except OutputParserException as e: - new_completion = await self.retry_chain.arun( - prompt=prompt_value.to_string(), completion=completion, error=repr(e) - ) - parsed_completion = self.parser.parse(new_completion) + retries = 0 - return parsed_completion + while retries <= self.max_retries: + try: + return await self.parser.aparse(completion) + except OutputParserException as e: + if retries == self.max_retries: + raise e + else: + retries += 1 + completion = await self.retry_chain.arun( + prompt=prompt_value.to_string(), + completion=completion, + error=repr(e), + ) + + raise OutputParserException("Failed to parse") def parse(self, completion: str) -> T: raise NotImplementedError(