From a402de3daeca5d294360cffd7bd7a9c499f74f8e Mon Sep 17 00:00:00 2001 From: hmasdev <73353463+hmasdev@users.noreply.github.com> Date: Thu, 18 Jul 2024 05:34:46 +0900 Subject: [PATCH] langchain[patch]: fix wrong `dict` key in `OutputFixingParser`, `RetryOutputParser` and `RetryWithErrorOutputParser` (#23967) # Description This PR aims to solve a bug in `OutputFixingParser`, `RetryOutputParser` and `RetryWithErrorOutputParser` The bug is that the wrong keyword argument was given to `retry_chain`. The correct keyword argument is 'completion', but 'input' is used. This pull request makes the following changes: 1. correct a `dict` key given to `retry_chain`; 2. add a test when using the default prompt. - `NAIVE_FIX_PROMPT` for `OutputFixingParser`; - `NAIVE_RETRY_PROMPT` for `RetryOutputParser`; - `NAIVE_RETRY_WITH_ERROR_PROMPT` for `RetryWithErrorOutputParser`; 3. ~~add comments on `retry_chain` input and output types~~ clarify `InputType` and `OutputType` of `retry_chain` # Issue The bug is pointed out in https://github.com/langchain-ai/langchain/pull/19792#issuecomment-2196512928 --------- Co-authored-by: Erick Friis --- .../langchain/langchain/output_parsers/fix.py | 29 ++- .../langchain/output_parsers/retry.py | 31 ++- .../unit_tests/output_parsers/test_fix.py | 131 +++++++++++- .../unit_tests/output_parsers/test_retry.py | 201 +++++++++++++++++- 4 files changed, 357 insertions(+), 35 deletions(-) diff --git a/libs/langchain/langchain/output_parsers/fix.py b/libs/langchain/langchain/output_parsers/fix.py index 849200105d9..3a52f442626 100644 --- a/libs/langchain/langchain/output_parsers/fix.py +++ b/libs/langchain/langchain/output_parsers/fix.py @@ -7,12 +7,19 @@ from langchain_core.language_models import BaseLanguageModel from langchain_core.output_parsers import BaseOutputParser from langchain_core.prompts import BasePromptTemplate from langchain_core.runnables import RunnableSerializable +from typing_extensions import TypedDict from langchain.output_parsers.prompts import NAIVE_FIX_PROMPT T = TypeVar("T") +class OutputFixingParserRetryChainInput(TypedDict, total=False): + instructions: str + completion: str + error: str + + class OutputFixingParser(BaseOutputParser[T]): """Wrap a parser and try to fix parsing errors.""" @@ -23,7 +30,9 @@ class OutputFixingParser(BaseOutputParser[T]): parser: BaseOutputParser[T] """The parser to use to parse the output.""" # Should be an LLMChain but we want to avoid top-level imports from langchain.chains - retry_chain: Union[RunnableSerializable, Any] + retry_chain: Union[ + RunnableSerializable[OutputFixingParserRetryChainInput, str], Any + ] """The RunnableSerializable to use to retry the completion (Legacy: LLMChain).""" max_retries: int = 1 """The maximum number of times to retry the parse.""" @@ -73,16 +82,16 @@ class OutputFixingParser(BaseOutputParser[T]): try: completion = self.retry_chain.invoke( dict( - instructions=self.parser.get_format_instructions(), # noqa: E501 - input=completion, + instructions=self.parser.get_format_instructions(), + completion=completion, error=repr(e), ) ) except (NotImplementedError, AttributeError): - # Case: self.parser does not have get_format_instructions # noqa: E501 + # Case: self.parser does not have get_format_instructions completion = self.retry_chain.invoke( dict( - input=completion, + completion=completion, error=repr(e), ) ) @@ -102,7 +111,7 @@ class OutputFixingParser(BaseOutputParser[T]): retries += 1 if self.legacy and hasattr(self.retry_chain, "arun"): completion = await self.retry_chain.arun( - instructions=self.parser.get_format_instructions(), # noqa: E501 + instructions=self.parser.get_format_instructions(), completion=completion, error=repr(e), ) @@ -110,16 +119,16 @@ class OutputFixingParser(BaseOutputParser[T]): try: completion = await self.retry_chain.ainvoke( dict( - instructions=self.parser.get_format_instructions(), # noqa: E501 - input=completion, + instructions=self.parser.get_format_instructions(), + completion=completion, error=repr(e), ) ) except (NotImplementedError, AttributeError): - # Case: self.parser does not have get_format_instructions # noqa: E501 + # Case: self.parser does not have get_format_instructions completion = await self.retry_chain.ainvoke( dict( - input=completion, + completion=completion, error=repr(e), ) ) diff --git a/libs/langchain/langchain/output_parsers/retry.py b/libs/langchain/langchain/output_parsers/retry.py index b82f1796571..7d5a383903a 100644 --- a/libs/langchain/langchain/output_parsers/retry.py +++ b/libs/langchain/langchain/output_parsers/retry.py @@ -8,6 +8,7 @@ from langchain_core.output_parsers import BaseOutputParser from langchain_core.prompt_values import PromptValue from langchain_core.prompts import BasePromptTemplate, PromptTemplate from langchain_core.runnables import RunnableSerializable +from typing_extensions import TypedDict NAIVE_COMPLETION_RETRY = """Prompt: {prompt} @@ -34,6 +35,17 @@ NAIVE_RETRY_WITH_ERROR_PROMPT = PromptTemplate.from_template( T = TypeVar("T") +class RetryOutputParserRetryChainInput(TypedDict): + prompt: str + completion: str + + +class RetryWithErrorOutputParserRetryChainInput(TypedDict): + prompt: str + completion: str + error: str + + class RetryOutputParser(BaseOutputParser[T]): """Wrap a parser and try to fix parsing errors. @@ -44,7 +56,7 @@ class RetryOutputParser(BaseOutputParser[T]): parser: BaseOutputParser[T] """The parser to use to parse the output.""" # Should be an LLMChain but we want to avoid top-level imports from langchain.chains - retry_chain: Union[RunnableSerializable, Any] + retry_chain: Union[RunnableSerializable[RetryOutputParserRetryChainInput, str], Any] """The RunnableSerializable to use to retry the completion (Legacy: LLMChain).""" max_retries: int = 1 """The maximum number of times to retry the parse.""" @@ -97,13 +109,12 @@ class RetryOutputParser(BaseOutputParser[T]): completion = self.retry_chain.run( prompt=prompt_value.to_string(), completion=completion, - error=repr(e), ) else: completion = self.retry_chain.invoke( dict( prompt=prompt_value.to_string(), - input=completion, + completion=completion, ) ) @@ -139,7 +150,7 @@ class RetryOutputParser(BaseOutputParser[T]): completion = await self.retry_chain.ainvoke( dict( prompt=prompt_value.to_string(), - input=completion, + completion=completion, ) ) @@ -174,8 +185,10 @@ class RetryWithErrorOutputParser(BaseOutputParser[T]): parser: BaseOutputParser[T] """The parser to use to parse the output.""" - # Should be an LLMChain but we want to avoid top-level imports from langchain.chains # noqa: E501 - retry_chain: Union[RunnableSerializable, Any] + # Should be an LLMChain but we want to avoid top-level imports from langchain.chains + retry_chain: Union[ + RunnableSerializable[RetryWithErrorOutputParserRetryChainInput, str], Any + ] """The RunnableSerializable to use to retry the completion (Legacy: LLMChain).""" max_retries: int = 1 """The maximum number of times to retry the parse.""" @@ -204,7 +217,7 @@ class RetryWithErrorOutputParser(BaseOutputParser[T]): chain = prompt | llm return cls(parser=parser, retry_chain=chain, max_retries=max_retries) - def parse_with_prompt(self, completion: str, prompt_value: PromptValue) -> T: # noqa: E501 + def parse_with_prompt(self, completion: str, prompt_value: PromptValue) -> T: retries = 0 while retries <= self.max_retries: @@ -224,7 +237,7 @@ class RetryWithErrorOutputParser(BaseOutputParser[T]): else: completion = self.retry_chain.invoke( dict( - input=completion, + completion=completion, prompt=prompt_value.to_string(), error=repr(e), ) @@ -253,7 +266,7 @@ class RetryWithErrorOutputParser(BaseOutputParser[T]): completion = await self.retry_chain.ainvoke( dict( prompt=prompt_value.to_string(), - input=completion, + completion=completion, error=repr(e), ) ) diff --git a/libs/langchain/tests/unit_tests/output_parsers/test_fix.py b/libs/langchain/tests/unit_tests/output_parsers/test_fix.py index eae0dfbdf95..a98d823a45b 100644 --- a/libs/langchain/tests/unit_tests/output_parsers/test_fix.py +++ b/libs/langchain/tests/unit_tests/output_parsers/test_fix.py @@ -1,12 +1,19 @@ -from typing import Any +from datetime import datetime as dt +from typing import Any, Callable, Dict, Optional, TypeVar import pytest from langchain_core.exceptions import OutputParserException -from langchain_core.runnables import RunnablePassthrough +from langchain_core.prompts.prompt import PromptTemplate +from langchain_core.runnables import Runnable, RunnableLambda, RunnablePassthrough +from pytest_mock import MockerFixture from langchain.output_parsers.boolean import BooleanOutputParser from langchain.output_parsers.datetime import DatetimeOutputParser from langchain.output_parsers.fix import BaseOutputParser, OutputFixingParser +from langchain.output_parsers.prompts import NAIVE_FIX_PROMPT +from langchain.pydantic_v1 import Extra + +T = TypeVar("T") class SuccessfulParseAfterRetries(BaseOutputParser[str]): @@ -22,7 +29,7 @@ class SuccessfulParseAfterRetries(BaseOutputParser[str]): return "parsed" -class SuccessfulParseAfterRetriesWithGetFormatInstructions(SuccessfulParseAfterRetries): # noqa +class SuccessfulParseAfterRetriesWithGetFormatInstructions(SuccessfulParseAfterRetries): def get_format_instructions(self) -> str: return "instructions" @@ -118,6 +125,120 @@ async def test_output_fixing_parser_aparse_fail() -> None: DatetimeOutputParser(), ], ) -def test_output_fixing_parser_output_type(base_parser: BaseOutputParser) -> None: # noqa: E501 - parser = OutputFixingParser(parser=base_parser, retry_chain=RunnablePassthrough()) # noqa: E501 +def test_output_fixing_parser_output_type( + base_parser: BaseOutputParser, +) -> None: + parser = OutputFixingParser(parser=base_parser, retry_chain=RunnablePassthrough()) assert parser.OutputType is base_parser.OutputType + + +@pytest.mark.parametrize( + "input,base_parser,retry_chain,expected", + [ + ( + "2024/07/08", + DatetimeOutputParser(), + NAIVE_FIX_PROMPT | RunnableLambda(lambda _: "2024-07-08T00:00:00.000000Z"), + dt(2024, 7, 8), + ), + ( + # Case: retry_chain.InputType does not have 'instructions' key + "2024/07/08", + DatetimeOutputParser(), + PromptTemplate.from_template("{completion}\n{error}") + | RunnableLambda(lambda _: "2024-07-08T00:00:00.000000Z"), + dt(2024, 7, 8), + ), + ], +) +def test_output_fixing_parser_parse_with_retry_chain( + input: str, + base_parser: BaseOutputParser[T], + retry_chain: Runnable[Dict[str, Any], str], + expected: T, + mocker: MockerFixture, +) -> None: + # preparation + # NOTE: Extra.allow is necessary in order to use spy and mock + retry_chain.Config.extra = Extra.allow # type: ignore + base_parser.Config.extra = Extra.allow # type: ignore + invoke_spy = mocker.spy(retry_chain, "invoke") + # NOTE: get_format_instructions of some parsers behave randomly + instructions = base_parser.get_format_instructions() + object.__setattr__(base_parser, "get_format_instructions", lambda: instructions) + # test + parser = OutputFixingParser( + parser=base_parser, + retry_chain=retry_chain, + legacy=False, + ) + assert parser.parse(input) == expected + invoke_spy.assert_called_once_with( + dict( + instructions=base_parser.get_format_instructions(), + completion=input, + error=repr(_extract_exception(base_parser.parse, input)), + ) + ) + + +@pytest.mark.parametrize( + "input,base_parser,retry_chain,expected", + [ + ( + "2024/07/08", + DatetimeOutputParser(), + NAIVE_FIX_PROMPT | RunnableLambda(lambda _: "2024-07-08T00:00:00.000000Z"), + dt(2024, 7, 8), + ), + ( + # Case: retry_chain.InputType does not have 'instructions' key + "2024/07/08", + DatetimeOutputParser(), + PromptTemplate.from_template("{completion}\n{error}") + | RunnableLambda(lambda _: "2024-07-08T00:00:00.000000Z"), + dt(2024, 7, 8), + ), + ], +) +async def test_output_fixing_parser_aparse_with_retry_chain( + input: str, + base_parser: BaseOutputParser[T], + retry_chain: Runnable[Dict[str, Any], str], + expected: T, + mocker: MockerFixture, +) -> None: + # preparation + # NOTE: Extra.allow is necessary in order to use spy and mock + retry_chain.Config.extra = Extra.allow # type: ignore + base_parser.Config.extra = Extra.allow # type: ignore + ainvoke_spy = mocker.spy(retry_chain, "ainvoke") + # NOTE: get_format_instructions of some parsers behave randomly + instructions = base_parser.get_format_instructions() + object.__setattr__(base_parser, "get_format_instructions", lambda: instructions) + # test + parser = OutputFixingParser( + parser=base_parser, + retry_chain=retry_chain, + legacy=False, + ) + assert (await parser.aparse(input)) == expected + ainvoke_spy.assert_called_once_with( + dict( + instructions=base_parser.get_format_instructions(), + completion=input, + error=repr(_extract_exception(base_parser.parse, input)), + ) + ) + + +def _extract_exception( + func: Callable[..., Any], + *args: Any, + **kwargs: Any, +) -> Optional[Exception]: + try: + func(*args, **kwargs) + except Exception as e: + return e + return None diff --git a/libs/langchain/tests/unit_tests/output_parsers/test_retry.py b/libs/langchain/tests/unit_tests/output_parsers/test_retry.py index 22c52735c36..7af3597f475 100644 --- a/libs/langchain/tests/unit_tests/output_parsers/test_retry.py +++ b/libs/langchain/tests/unit_tests/output_parsers/test_retry.py @@ -1,24 +1,29 @@ -from typing import Any +from datetime import datetime as dt +from typing import Any, Callable, Dict, Optional, TypeVar import pytest -from langchain_core.prompt_values import StringPromptValue -from langchain_core.runnables import RunnablePassthrough +from langchain_core.prompt_values import PromptValue, StringPromptValue +from langchain_core.runnables import Runnable, RunnableLambda, RunnablePassthrough +from pytest_mock import MockerFixture from langchain.output_parsers.boolean import BooleanOutputParser from langchain.output_parsers.datetime import DatetimeOutputParser from langchain.output_parsers.retry import ( + NAIVE_RETRY_PROMPT, + NAIVE_RETRY_WITH_ERROR_PROMPT, BaseOutputParser, OutputParserException, RetryOutputParser, RetryWithErrorOutputParser, ) +from langchain.pydantic_v1 import Extra + +T = TypeVar("T") class SuccessfulParseAfterRetries(BaseOutputParser[str]): parse_count: int = 0 # Number of times parse has been called - attemp_count_before_success: ( - int # Number of times to fail before succeeding # noqa - ) + attemp_count_before_success: int # Number of times to fail before succeeding error_msg: str = "error" def parse(self, *args: Any, **kwargs: Any) -> str: @@ -37,7 +42,7 @@ def test_retry_output_parser_parse_with_prompt() -> None: max_retries=n, # n times to retry, that is, (n+1) times call legacy=False, ) - actual = parser.parse_with_prompt("completion", StringPromptValue(text="dummy")) # noqa: E501 + actual = parser.parse_with_prompt("completion", StringPromptValue(text="dummy")) assert actual == "parsed" assert base_parser.parse_count == n + 1 @@ -82,7 +87,7 @@ async def test_retry_output_parser_aparse_with_prompt_fail() -> None: legacy=False, ) with pytest.raises(OutputParserException): - await parser.aparse_with_prompt("completion", StringPromptValue(text="dummy")) # noqa: E501 + await parser.aparse_with_prompt("completion", StringPromptValue(text="dummy")) assert base_parser.parse_count == n @@ -121,7 +126,7 @@ def test_retry_with_error_output_parser_parse_with_prompt() -> None: max_retries=n, # n times to retry, that is, (n+1) times call legacy=False, ) - actual = parser.parse_with_prompt("completion", StringPromptValue(text="dummy")) # noqa: E501 + actual = parser.parse_with_prompt("completion", StringPromptValue(text="dummy")) assert actual == "parsed" assert base_parser.parse_count == n + 1 @@ -156,7 +161,7 @@ async def test_retry_with_error_output_parser_aparse_with_prompt() -> None: assert base_parser.parse_count == n + 1 -async def test_retry_with_error_output_parser_aparse_with_prompt_fail() -> None: # noqa: E501 +async def test_retry_with_error_output_parser_aparse_with_prompt_fail() -> None: n: int = 5 # Success on the (n+1)-th attempt base_parser = SuccessfulParseAfterRetries(attemp_count_before_success=n) parser = RetryWithErrorOutputParser( @@ -166,7 +171,7 @@ async def test_retry_with_error_output_parser_aparse_with_prompt_fail() -> None: legacy=False, ) with pytest.raises(OutputParserException): - await parser.aparse_with_prompt("completion", StringPromptValue(text="dummy")) # noqa: E501 + await parser.aparse_with_prompt("completion", StringPromptValue(text="dummy")) assert base_parser.parse_count == n @@ -196,3 +201,177 @@ def test_retry_with_error_output_parser_parse_is_not_implemented() -> None: ) with pytest.raises(NotImplementedError): parser.parse("completion") + + +@pytest.mark.parametrize( + "input,prompt,base_parser,retry_chain,expected", + [ + ( + "2024/07/08", + StringPromptValue(text="dummy"), + DatetimeOutputParser(), + NAIVE_RETRY_PROMPT + | RunnableLambda(lambda _: "2024-07-08T00:00:00.000000Z"), + dt(2024, 7, 8), + ) + ], +) +def test_retry_output_parser_parse_with_prompt_with_retry_chain( + input: str, + prompt: PromptValue, + base_parser: BaseOutputParser[T], + retry_chain: Runnable[Dict[str, Any], str], + expected: T, + mocker: MockerFixture, +) -> None: + # preparation + # NOTE: Extra.allow is necessary in order to use spy and mock + retry_chain.Config.extra = Extra.allow # type: ignore + invoke_spy = mocker.spy(retry_chain, "invoke") + # test + parser = RetryOutputParser( + parser=base_parser, + retry_chain=retry_chain, + legacy=False, + ) + assert parser.parse_with_prompt(input, prompt) == expected + invoke_spy.assert_called_once_with( + dict( + prompt=prompt.to_string(), + completion=input, + ) + ) + + +@pytest.mark.parametrize( + "input,prompt,base_parser,retry_chain,expected", + [ + ( + "2024/07/08", + StringPromptValue(text="dummy"), + DatetimeOutputParser(), + NAIVE_RETRY_PROMPT + | RunnableLambda(lambda _: "2024-07-08T00:00:00.000000Z"), + dt(2024, 7, 8), + ) + ], +) +async def test_retry_output_parser_aparse_with_prompt_with_retry_chain( + input: str, + prompt: PromptValue, + base_parser: BaseOutputParser[T], + retry_chain: Runnable[Dict[str, Any], str], + expected: T, + mocker: MockerFixture, +) -> None: + # preparation + # NOTE: Extra.allow is necessary in order to use spy and mock + retry_chain.Config.extra = Extra.allow # type: ignore + ainvoke_spy = mocker.spy(retry_chain, "ainvoke") + # test + parser = RetryOutputParser( + parser=base_parser, + retry_chain=retry_chain, + legacy=False, + ) + assert (await parser.aparse_with_prompt(input, prompt)) == expected + ainvoke_spy.assert_called_once_with( + dict( + prompt=prompt.to_string(), + completion=input, + ) + ) + + +@pytest.mark.parametrize( + "input,prompt,base_parser,retry_chain,expected", + [ + ( + "2024/07/08", + StringPromptValue(text="dummy"), + DatetimeOutputParser(), + NAIVE_RETRY_WITH_ERROR_PROMPT + | RunnableLambda(lambda _: "2024-07-08T00:00:00.000000Z"), + dt(2024, 7, 8), + ) + ], +) +def test_retry_with_error_output_parser_parse_with_prompt_with_retry_chain( + input: str, + prompt: PromptValue, + base_parser: BaseOutputParser[T], + retry_chain: Runnable[Dict[str, Any], str], + expected: T, + mocker: MockerFixture, +) -> None: + # preparation + # NOTE: Extra.allow is necessary in order to use spy and mock + retry_chain.Config.extra = Extra.allow # type: ignore + invoke_spy = mocker.spy(retry_chain, "invoke") + # test + parser = RetryWithErrorOutputParser( + parser=base_parser, + retry_chain=retry_chain, + legacy=False, + ) + assert parser.parse_with_prompt(input, prompt) == expected + invoke_spy.assert_called_once_with( + dict( + prompt=prompt.to_string(), + completion=input, + error=repr(_extract_exception(base_parser.parse, input)), + ) + ) + + +@pytest.mark.parametrize( + "input,prompt,base_parser,retry_chain,expected", + [ + ( + "2024/07/08", + StringPromptValue(text="dummy"), + DatetimeOutputParser(), + NAIVE_RETRY_WITH_ERROR_PROMPT + | RunnableLambda(lambda _: "2024-07-08T00:00:00.000000Z"), + dt(2024, 7, 8), + ) + ], +) +async def test_retry_with_error_output_parser_aparse_with_prompt_with_retry_chain( + input: str, + prompt: PromptValue, + base_parser: BaseOutputParser[T], + retry_chain: Runnable[Dict[str, Any], str], + expected: T, + mocker: MockerFixture, +) -> None: + # preparation + # NOTE: Extra.allow is necessary in order to use spy and mock + retry_chain.Config.extra = Extra.allow # type: ignore + ainvoke_spy = mocker.spy(retry_chain, "ainvoke") + # test + parser = RetryWithErrorOutputParser( + parser=base_parser, + retry_chain=retry_chain, + legacy=False, + ) + assert (await parser.aparse_with_prompt(input, prompt)) == expected + ainvoke_spy.assert_called_once_with( + dict( + prompt=prompt.to_string(), + completion=input, + error=repr(_extract_exception(base_parser.parse, input)), + ) + ) + + +def _extract_exception( + func: Callable[..., Any], + *args: Any, + **kwargs: Any, +) -> Optional[Exception]: + try: + func(*args, **kwargs) + except Exception as e: + return e + return None