langchain[patch]: fix wrong dict key in OutputFixingParser, RetryOutputParser and RetryWithErrorOutputParser (#23967)

# Description
This PR fixes a bug in `OutputFixingParser`, `RetryOutputParser`
and `RetryWithErrorOutputParser`.
The bug is that the wrong key was used in the `dict` passed to `retry_chain`.
The correct key is 'completion', but 'input' was used.

This pull request makes the following changes:
1. corrects the `dict` key given to `retry_chain`;
2. adds tests that exercise each parser with its default prompt:
   - `NAIVE_FIX_PROMPT` for `OutputFixingParser`;
   - `NAIVE_RETRY_PROMPT` for `RetryOutputParser`;
   - `NAIVE_RETRY_WITH_ERROR_PROMPT` for `RetryWithErrorOutputParser`;
3. ~~add comments on `retry_chain` input and output types~~ clarify
`InputType` and `OutputType` of `retry_chain`

# Issue
The bug was pointed out in
https://github.com/langchain-ai/langchain/pull/19792#issuecomment-2196512928

---------

Co-authored-by: Erick Friis <erick@langchain.dev>
This commit is contained in:
hmasdev 2024-07-18 05:34:46 +09:00 committed by GitHub
parent a47f69a120
commit a402de3dae
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 357 additions and 35 deletions

View File

@ -7,12 +7,19 @@ from langchain_core.language_models import BaseLanguageModel
from langchain_core.output_parsers import BaseOutputParser
from langchain_core.prompts import BasePromptTemplate
from langchain_core.runnables import RunnableSerializable
from typing_extensions import TypedDict
from langchain.output_parsers.prompts import NAIVE_FIX_PROMPT
T = TypeVar("T")
class OutputFixingParserRetryChainInput(TypedDict, total=False):
    """Input mapping handed to ``OutputFixingParser.retry_chain``.

    Declared with ``total=False`` so that keys may be omitted: the parser
    leaves out ``instructions`` when the wrapped parser's
    ``get_format_instructions`` raises ``NotImplementedError`` or
    ``AttributeError``.
    """

    # Result of the wrapped parser's ``get_format_instructions()``.
    instructions: str
    # The completion text that is being retried.
    completion: str
    # ``repr()`` of the exception that triggered the retry.
    error: str
class OutputFixingParser(BaseOutputParser[T]):
"""Wrap a parser and try to fix parsing errors."""
@ -23,7 +30,9 @@ class OutputFixingParser(BaseOutputParser[T]):
parser: BaseOutputParser[T]
"""The parser to use to parse the output."""
# Should be an LLMChain but we want to avoid top-level imports from langchain.chains
retry_chain: Union[RunnableSerializable, Any]
retry_chain: Union[
RunnableSerializable[OutputFixingParserRetryChainInput, str], Any
]
"""The RunnableSerializable to use to retry the completion (Legacy: LLMChain)."""
max_retries: int = 1
"""The maximum number of times to retry the parse."""
@ -73,16 +82,16 @@ class OutputFixingParser(BaseOutputParser[T]):
try:
completion = self.retry_chain.invoke(
dict(
instructions=self.parser.get_format_instructions(), # noqa: E501
input=completion,
instructions=self.parser.get_format_instructions(),
completion=completion,
error=repr(e),
)
)
except (NotImplementedError, AttributeError):
# Case: self.parser does not have get_format_instructions # noqa: E501
# Case: self.parser does not have get_format_instructions
completion = self.retry_chain.invoke(
dict(
input=completion,
completion=completion,
error=repr(e),
)
)
@ -102,7 +111,7 @@ class OutputFixingParser(BaseOutputParser[T]):
retries += 1
if self.legacy and hasattr(self.retry_chain, "arun"):
completion = await self.retry_chain.arun(
instructions=self.parser.get_format_instructions(), # noqa: E501
instructions=self.parser.get_format_instructions(),
completion=completion,
error=repr(e),
)
@ -110,16 +119,16 @@ class OutputFixingParser(BaseOutputParser[T]):
try:
completion = await self.retry_chain.ainvoke(
dict(
instructions=self.parser.get_format_instructions(), # noqa: E501
input=completion,
instructions=self.parser.get_format_instructions(),
completion=completion,
error=repr(e),
)
)
except (NotImplementedError, AttributeError):
# Case: self.parser does not have get_format_instructions # noqa: E501
# Case: self.parser does not have get_format_instructions
completion = await self.retry_chain.ainvoke(
dict(
input=completion,
completion=completion,
error=repr(e),
)
)

View File

@ -8,6 +8,7 @@ from langchain_core.output_parsers import BaseOutputParser
from langchain_core.prompt_values import PromptValue
from langchain_core.prompts import BasePromptTemplate, PromptTemplate
from langchain_core.runnables import RunnableSerializable
from typing_extensions import TypedDict
NAIVE_COMPLETION_RETRY = """Prompt:
{prompt}
@ -34,6 +35,17 @@ NAIVE_RETRY_WITH_ERROR_PROMPT = PromptTemplate.from_template(
T = TypeVar("T")
class RetryOutputParserRetryChainInput(TypedDict):
    """Input mapping handed to ``RetryOutputParser.retry_chain``."""

    # The original prompt rendered with ``prompt_value.to_string()``.
    prompt: str
    # The completion text that is being retried.
    completion: str
class RetryWithErrorOutputParserRetryChainInput(TypedDict):
    """Input mapping handed to ``RetryWithErrorOutputParser.retry_chain``."""

    # The original prompt rendered with ``prompt_value.to_string()``.
    prompt: str
    # The completion text that is being retried.
    completion: str
    # ``repr()`` of the exception that triggered the retry.
    error: str
class RetryOutputParser(BaseOutputParser[T]):
"""Wrap a parser and try to fix parsing errors.
@ -44,7 +56,7 @@ class RetryOutputParser(BaseOutputParser[T]):
parser: BaseOutputParser[T]
"""The parser to use to parse the output."""
# Should be an LLMChain but we want to avoid top-level imports from langchain.chains
retry_chain: Union[RunnableSerializable, Any]
retry_chain: Union[RunnableSerializable[RetryOutputParserRetryChainInput, str], Any]
"""The RunnableSerializable to use to retry the completion (Legacy: LLMChain)."""
max_retries: int = 1
"""The maximum number of times to retry the parse."""
@ -97,13 +109,12 @@ class RetryOutputParser(BaseOutputParser[T]):
completion = self.retry_chain.run(
prompt=prompt_value.to_string(),
completion=completion,
error=repr(e),
)
else:
completion = self.retry_chain.invoke(
dict(
prompt=prompt_value.to_string(),
input=completion,
completion=completion,
)
)
@ -139,7 +150,7 @@ class RetryOutputParser(BaseOutputParser[T]):
completion = await self.retry_chain.ainvoke(
dict(
prompt=prompt_value.to_string(),
input=completion,
completion=completion,
)
)
@ -174,8 +185,10 @@ class RetryWithErrorOutputParser(BaseOutputParser[T]):
parser: BaseOutputParser[T]
"""The parser to use to parse the output."""
# Should be an LLMChain but we want to avoid top-level imports from langchain.chains # noqa: E501
retry_chain: Union[RunnableSerializable, Any]
# Should be an LLMChain but we want to avoid top-level imports from langchain.chains
retry_chain: Union[
RunnableSerializable[RetryWithErrorOutputParserRetryChainInput, str], Any
]
"""The RunnableSerializable to use to retry the completion (Legacy: LLMChain)."""
max_retries: int = 1
"""The maximum number of times to retry the parse."""
@ -204,7 +217,7 @@ class RetryWithErrorOutputParser(BaseOutputParser[T]):
chain = prompt | llm
return cls(parser=parser, retry_chain=chain, max_retries=max_retries)
def parse_with_prompt(self, completion: str, prompt_value: PromptValue) -> T: # noqa: E501
def parse_with_prompt(self, completion: str, prompt_value: PromptValue) -> T:
retries = 0
while retries <= self.max_retries:
@ -224,7 +237,7 @@ class RetryWithErrorOutputParser(BaseOutputParser[T]):
else:
completion = self.retry_chain.invoke(
dict(
input=completion,
completion=completion,
prompt=prompt_value.to_string(),
error=repr(e),
)
@ -253,7 +266,7 @@ class RetryWithErrorOutputParser(BaseOutputParser[T]):
completion = await self.retry_chain.ainvoke(
dict(
prompt=prompt_value.to_string(),
input=completion,
completion=completion,
error=repr(e),
)
)

View File

@ -1,12 +1,19 @@
from typing import Any
from datetime import datetime as dt
from typing import Any, Callable, Dict, Optional, TypeVar
import pytest
from langchain_core.exceptions import OutputParserException
from langchain_core.runnables import RunnablePassthrough
from langchain_core.prompts.prompt import PromptTemplate
from langchain_core.runnables import Runnable, RunnableLambda, RunnablePassthrough
from pytest_mock import MockerFixture
from langchain.output_parsers.boolean import BooleanOutputParser
from langchain.output_parsers.datetime import DatetimeOutputParser
from langchain.output_parsers.fix import BaseOutputParser, OutputFixingParser
from langchain.output_parsers.prompts import NAIVE_FIX_PROMPT
from langchain.pydantic_v1 import Extra
T = TypeVar("T")
class SuccessfulParseAfterRetries(BaseOutputParser[str]):
@ -22,7 +29,7 @@ class SuccessfulParseAfterRetries(BaseOutputParser[str]):
return "parsed"
class SuccessfulParseAfterRetriesWithGetFormatInstructions(SuccessfulParseAfterRetries): # noqa
class SuccessfulParseAfterRetriesWithGetFormatInstructions(SuccessfulParseAfterRetries):
def get_format_instructions(self) -> str:
return "instructions"
@ -118,6 +125,120 @@ async def test_output_fixing_parser_aparse_fail() -> None:
DatetimeOutputParser(),
],
)
def test_output_fixing_parser_output_type(base_parser: BaseOutputParser) -> None: # noqa: E501
parser = OutputFixingParser(parser=base_parser, retry_chain=RunnablePassthrough()) # noqa: E501
def test_output_fixing_parser_output_type(
base_parser: BaseOutputParser,
) -> None:
parser = OutputFixingParser(parser=base_parser, retry_chain=RunnablePassthrough())
assert parser.OutputType is base_parser.OutputType
@pytest.mark.parametrize(
    "input,base_parser,retry_chain,expected",
    [
        (
            "2024/07/08",
            DatetimeOutputParser(),
            NAIVE_FIX_PROMPT | RunnableLambda(lambda _: "2024-07-08T00:00:00.000000Z"),
            dt(2024, 7, 8),
        ),
        (
            # Case: the retry prompt only uses 'completion' and 'error'
            # (it has no 'instructions' variable)
            "2024/07/08",
            DatetimeOutputParser(),
            PromptTemplate.from_template("{completion}\n{error}")
            | RunnableLambda(lambda _: "2024-07-08T00:00:00.000000Z"),
            dt(2024, 7, 8),
        ),
    ],
)
def test_output_fixing_parser_parse_with_retry_chain(
    input: str,
    base_parser: BaseOutputParser[T],
    retry_chain: Runnable[Dict[str, Any], str],
    expected: T,
    mocker: MockerFixture,
) -> None:
    """Check the dict that ``OutputFixingParser.parse`` sends to ``retry_chain``.

    Asserts that ``parse`` returns ``expected`` and that ``retry_chain.invoke``
    was called exactly once with the ``instructions``, ``completion`` and
    ``error`` keys.
    """
    # preparation
    # NOTE: Extra.allow is necessary in order to use spy and mock
    # NOTE(review): ``Config`` looks like a class-level object, so these
    # assignments likely persist beyond this test -- confirm that is acceptable.
    retry_chain.Config.extra = Extra.allow  # type: ignore
    base_parser.Config.extra = Extra.allow  # type: ignore
    invoke_spy = mocker.spy(retry_chain, "invoke")
    # NOTE: get_format_instructions of some parsers is not deterministic, so
    # pin it to a single value for the duration of the test
    instructions = base_parser.get_format_instructions()
    object.__setattr__(base_parser, "get_format_instructions", lambda: instructions)
    # test
    parser = OutputFixingParser(
        parser=base_parser,
        retry_chain=retry_chain,
        legacy=False,
    )
    assert parser.parse(input) == expected
    invoke_spy.assert_called_once_with(
        dict(
            instructions=base_parser.get_format_instructions(),
            completion=input,
            # Re-run the base parser on the raw input to obtain the expected
            # error text
            error=repr(_extract_exception(base_parser.parse, input)),
        )
    )
@pytest.mark.parametrize(
    "input,base_parser,retry_chain,expected",
    [
        (
            "2024/07/08",
            DatetimeOutputParser(),
            NAIVE_FIX_PROMPT | RunnableLambda(lambda _: "2024-07-08T00:00:00.000000Z"),
            dt(2024, 7, 8),
        ),
        (
            # Case: the retry prompt only uses 'completion' and 'error'
            # (it has no 'instructions' variable)
            "2024/07/08",
            DatetimeOutputParser(),
            PromptTemplate.from_template("{completion}\n{error}")
            | RunnableLambda(lambda _: "2024-07-08T00:00:00.000000Z"),
            dt(2024, 7, 8),
        ),
    ],
)
async def test_output_fixing_parser_aparse_with_retry_chain(
    input: str,
    base_parser: BaseOutputParser[T],
    retry_chain: Runnable[Dict[str, Any], str],
    expected: T,
    mocker: MockerFixture,
) -> None:
    """Check the dict that ``OutputFixingParser.aparse`` sends to ``retry_chain``.

    Asserts that ``aparse`` returns ``expected`` and that
    ``retry_chain.ainvoke`` was called exactly once with the ``instructions``,
    ``completion`` and ``error`` keys.
    """
    # preparation
    # NOTE: Extra.allow is necessary in order to use spy and mock
    # NOTE(review): ``Config`` looks like a class-level object, so these
    # assignments likely persist beyond this test -- confirm that is acceptable.
    retry_chain.Config.extra = Extra.allow  # type: ignore
    base_parser.Config.extra = Extra.allow  # type: ignore
    ainvoke_spy = mocker.spy(retry_chain, "ainvoke")
    # NOTE: get_format_instructions of some parsers is not deterministic, so
    # pin it to a single value for the duration of the test
    instructions = base_parser.get_format_instructions()
    object.__setattr__(base_parser, "get_format_instructions", lambda: instructions)
    # test
    parser = OutputFixingParser(
        parser=base_parser,
        retry_chain=retry_chain,
        legacy=False,
    )
    assert (await parser.aparse(input)) == expected
    ainvoke_spy.assert_called_once_with(
        dict(
            instructions=base_parser.get_format_instructions(),
            completion=input,
            # Re-run the base parser on the raw input to obtain the expected
            # error text
            error=repr(_extract_exception(base_parser.parse, input)),
        )
    )
def _extract_exception(
func: Callable[..., Any],
*args: Any,
**kwargs: Any,
) -> Optional[Exception]:
try:
func(*args, **kwargs)
except Exception as e:
return e
return None

View File

@ -1,24 +1,29 @@
from typing import Any
from datetime import datetime as dt
from typing import Any, Callable, Dict, Optional, TypeVar
import pytest
from langchain_core.prompt_values import StringPromptValue
from langchain_core.runnables import RunnablePassthrough
from langchain_core.prompt_values import PromptValue, StringPromptValue
from langchain_core.runnables import Runnable, RunnableLambda, RunnablePassthrough
from pytest_mock import MockerFixture
from langchain.output_parsers.boolean import BooleanOutputParser
from langchain.output_parsers.datetime import DatetimeOutputParser
from langchain.output_parsers.retry import (
NAIVE_RETRY_PROMPT,
NAIVE_RETRY_WITH_ERROR_PROMPT,
BaseOutputParser,
OutputParserException,
RetryOutputParser,
RetryWithErrorOutputParser,
)
from langchain.pydantic_v1 import Extra
T = TypeVar("T")
class SuccessfulParseAfterRetries(BaseOutputParser[str]):
parse_count: int = 0 # Number of times parse has been called
attemp_count_before_success: (
int # Number of times to fail before succeeding # noqa
)
attemp_count_before_success: int # Number of times to fail before succeeding
error_msg: str = "error"
def parse(self, *args: Any, **kwargs: Any) -> str:
@ -37,7 +42,7 @@ def test_retry_output_parser_parse_with_prompt() -> None:
max_retries=n, # n times to retry, that is, (n+1) times call
legacy=False,
)
actual = parser.parse_with_prompt("completion", StringPromptValue(text="dummy")) # noqa: E501
actual = parser.parse_with_prompt("completion", StringPromptValue(text="dummy"))
assert actual == "parsed"
assert base_parser.parse_count == n + 1
@ -82,7 +87,7 @@ async def test_retry_output_parser_aparse_with_prompt_fail() -> None:
legacy=False,
)
with pytest.raises(OutputParserException):
await parser.aparse_with_prompt("completion", StringPromptValue(text="dummy")) # noqa: E501
await parser.aparse_with_prompt("completion", StringPromptValue(text="dummy"))
assert base_parser.parse_count == n
@ -121,7 +126,7 @@ def test_retry_with_error_output_parser_parse_with_prompt() -> None:
max_retries=n, # n times to retry, that is, (n+1) times call
legacy=False,
)
actual = parser.parse_with_prompt("completion", StringPromptValue(text="dummy")) # noqa: E501
actual = parser.parse_with_prompt("completion", StringPromptValue(text="dummy"))
assert actual == "parsed"
assert base_parser.parse_count == n + 1
@ -156,7 +161,7 @@ async def test_retry_with_error_output_parser_aparse_with_prompt() -> None:
assert base_parser.parse_count == n + 1
async def test_retry_with_error_output_parser_aparse_with_prompt_fail() -> None: # noqa: E501
async def test_retry_with_error_output_parser_aparse_with_prompt_fail() -> None:
n: int = 5 # Success on the (n+1)-th attempt
base_parser = SuccessfulParseAfterRetries(attemp_count_before_success=n)
parser = RetryWithErrorOutputParser(
@ -166,7 +171,7 @@ async def test_retry_with_error_output_parser_aparse_with_prompt_fail() -> None:
legacy=False,
)
with pytest.raises(OutputParserException):
await parser.aparse_with_prompt("completion", StringPromptValue(text="dummy")) # noqa: E501
await parser.aparse_with_prompt("completion", StringPromptValue(text="dummy"))
assert base_parser.parse_count == n
@ -196,3 +201,177 @@ def test_retry_with_error_output_parser_parse_is_not_implemented() -> None:
)
with pytest.raises(NotImplementedError):
parser.parse("completion")
@pytest.mark.parametrize(
    "input,prompt,base_parser,retry_chain,expected",
    [
        (
            "2024/07/08",
            StringPromptValue(text="dummy"),
            DatetimeOutputParser(),
            NAIVE_RETRY_PROMPT
            | RunnableLambda(lambda _: "2024-07-08T00:00:00.000000Z"),
            dt(2024, 7, 8),
        )
    ],
)
def test_retry_output_parser_parse_with_prompt_with_retry_chain(
    input: str,
    prompt: PromptValue,
    base_parser: BaseOutputParser[T],
    retry_chain: Runnable[Dict[str, Any], str],
    expected: T,
    mocker: MockerFixture,
) -> None:
    """Check the dict ``RetryOutputParser.parse_with_prompt`` sends on retry.

    Asserts that ``parse_with_prompt`` returns ``expected`` and that
    ``retry_chain.invoke`` was called exactly once with the ``prompt`` and
    ``completion`` keys.
    """
    # preparation
    # NOTE: Extra.allow is necessary in order to use spy and mock
    # NOTE(review): ``Config`` looks like a class-level object, so this
    # assignment likely persists beyond this test -- confirm that is acceptable.
    retry_chain.Config.extra = Extra.allow  # type: ignore
    invoke_spy = mocker.spy(retry_chain, "invoke")
    # test
    parser = RetryOutputParser(
        parser=base_parser,
        retry_chain=retry_chain,
        legacy=False,
    )
    assert parser.parse_with_prompt(input, prompt) == expected
    invoke_spy.assert_called_once_with(
        dict(
            prompt=prompt.to_string(),
            completion=input,
        )
    )
@pytest.mark.parametrize(
    "input,prompt,base_parser,retry_chain,expected",
    [
        (
            "2024/07/08",
            StringPromptValue(text="dummy"),
            DatetimeOutputParser(),
            NAIVE_RETRY_PROMPT
            | RunnableLambda(lambda _: "2024-07-08T00:00:00.000000Z"),
            dt(2024, 7, 8),
        )
    ],
)
async def test_retry_output_parser_aparse_with_prompt_with_retry_chain(
    input: str,
    prompt: PromptValue,
    base_parser: BaseOutputParser[T],
    retry_chain: Runnable[Dict[str, Any], str],
    expected: T,
    mocker: MockerFixture,
) -> None:
    """Check the dict ``RetryOutputParser.aparse_with_prompt`` sends on retry.

    Asserts that ``aparse_with_prompt`` returns ``expected`` and that
    ``retry_chain.ainvoke`` was called exactly once with the ``prompt`` and
    ``completion`` keys.
    """
    # preparation
    # NOTE: Extra.allow is necessary in order to use spy and mock
    # NOTE(review): ``Config`` looks like a class-level object, so this
    # assignment likely persists beyond this test -- confirm that is acceptable.
    retry_chain.Config.extra = Extra.allow  # type: ignore
    ainvoke_spy = mocker.spy(retry_chain, "ainvoke")
    # test
    parser = RetryOutputParser(
        parser=base_parser,
        retry_chain=retry_chain,
        legacy=False,
    )
    assert (await parser.aparse_with_prompt(input, prompt)) == expected
    ainvoke_spy.assert_called_once_with(
        dict(
            prompt=prompt.to_string(),
            completion=input,
        )
    )
@pytest.mark.parametrize(
    "input,prompt,base_parser,retry_chain,expected",
    [
        (
            "2024/07/08",
            StringPromptValue(text="dummy"),
            DatetimeOutputParser(),
            NAIVE_RETRY_WITH_ERROR_PROMPT
            | RunnableLambda(lambda _: "2024-07-08T00:00:00.000000Z"),
            dt(2024, 7, 8),
        )
    ],
)
def test_retry_with_error_output_parser_parse_with_prompt_with_retry_chain(
    input: str,
    prompt: PromptValue,
    base_parser: BaseOutputParser[T],
    retry_chain: Runnable[Dict[str, Any], str],
    expected: T,
    mocker: MockerFixture,
) -> None:
    """Check the dict ``RetryWithErrorOutputParser.parse_with_prompt`` sends.

    Asserts that ``parse_with_prompt`` returns ``expected`` and that
    ``retry_chain.invoke`` was called exactly once with the ``prompt``,
    ``completion`` and ``error`` keys.
    """
    # preparation
    # NOTE: Extra.allow is necessary in order to use spy and mock
    # NOTE(review): ``Config`` looks like a class-level object, so this
    # assignment likely persists beyond this test -- confirm that is acceptable.
    retry_chain.Config.extra = Extra.allow  # type: ignore
    invoke_spy = mocker.spy(retry_chain, "invoke")
    # test
    parser = RetryWithErrorOutputParser(
        parser=base_parser,
        retry_chain=retry_chain,
        legacy=False,
    )
    assert parser.parse_with_prompt(input, prompt) == expected
    invoke_spy.assert_called_once_with(
        dict(
            prompt=prompt.to_string(),
            completion=input,
            # Re-run the base parser on the raw input to obtain the expected
            # error text
            error=repr(_extract_exception(base_parser.parse, input)),
        )
    )
@pytest.mark.parametrize(
    "input,prompt,base_parser,retry_chain,expected",
    [
        (
            "2024/07/08",
            StringPromptValue(text="dummy"),
            DatetimeOutputParser(),
            NAIVE_RETRY_WITH_ERROR_PROMPT
            | RunnableLambda(lambda _: "2024-07-08T00:00:00.000000Z"),
            dt(2024, 7, 8),
        )
    ],
)
async def test_retry_with_error_output_parser_aparse_with_prompt_with_retry_chain(
    input: str,
    prompt: PromptValue,
    base_parser: BaseOutputParser[T],
    retry_chain: Runnable[Dict[str, Any], str],
    expected: T,
    mocker: MockerFixture,
) -> None:
    """Check the dict ``RetryWithErrorOutputParser.aparse_with_prompt`` sends.

    Asserts that ``aparse_with_prompt`` returns ``expected`` and that
    ``retry_chain.ainvoke`` was called exactly once with the ``prompt``,
    ``completion`` and ``error`` keys.
    """
    # preparation
    # NOTE: Extra.allow is necessary in order to use spy and mock
    # NOTE(review): ``Config`` looks like a class-level object, so this
    # assignment likely persists beyond this test -- confirm that is acceptable.
    retry_chain.Config.extra = Extra.allow  # type: ignore
    ainvoke_spy = mocker.spy(retry_chain, "ainvoke")
    # test
    parser = RetryWithErrorOutputParser(
        parser=base_parser,
        retry_chain=retry_chain,
        legacy=False,
    )
    assert (await parser.aparse_with_prompt(input, prompt)) == expected
    ainvoke_spy.assert_called_once_with(
        dict(
            prompt=prompt.to_string(),
            completion=input,
            # Re-run the base parser on the raw input to obtain the expected
            # error text
            error=repr(_extract_exception(base_parser.parse, input)),
        )
    )
def _extract_exception(
func: Callable[..., Any],
*args: Any,
**kwargs: Any,
) -> Optional[Exception]:
try:
func(*args, **kwargs)
except Exception as e:
return e
return None