mirror of
https://github.com/hwchase17/langchain.git
synced 2025-06-22 14:49:29 +00:00
langchain[patch]: fix wrong dict
key in OutputFixingParser
, RetryOutputParser
and RetryWithErrorOutputParser
(#23967)
# Description This PR aims to solve a bug in `OutputFixingParser`, `RetryOutputParser` and `RetryWithErrorOutputParser` The bug is that the wrong keyword argument was given to `retry_chain`. The correct keyword argument is 'completion', but 'input' is used. This pull request makes the following changes: 1. correct a `dict` key given to `retry_chain`; 2. add a test when using the default prompt. - `NAIVE_FIX_PROMPT` for `OutputFixingParser`; - `NAIVE_RETRY_PROMPT` for `RetryOutputParser`; - `NAIVE_RETRY_WITH_ERROR_PROMPT` for `RetryWithErrorOutputParser`; 3. ~~add comments on `retry_chain` input and output types~~ clarify `InputType` and `OutputType` of `retry_chain` # Issue The bug is pointed out in https://github.com/langchain-ai/langchain/pull/19792#issuecomment-2196512928 --------- Co-authored-by: Erick Friis <erick@langchain.dev>
This commit is contained in:
parent
a47f69a120
commit
a402de3dae
@ -7,12 +7,19 @@ from langchain_core.language_models import BaseLanguageModel
|
||||
from langchain_core.output_parsers import BaseOutputParser
|
||||
from langchain_core.prompts import BasePromptTemplate
|
||||
from langchain_core.runnables import RunnableSerializable
|
||||
from typing_extensions import TypedDict
|
||||
|
||||
from langchain.output_parsers.prompts import NAIVE_FIX_PROMPT
|
||||
|
||||
T = TypeVar("T")
|
||||
|
||||
|
||||
class OutputFixingParserRetryChainInput(TypedDict, total=False):
|
||||
instructions: str
|
||||
completion: str
|
||||
error: str
|
||||
|
||||
|
||||
class OutputFixingParser(BaseOutputParser[T]):
|
||||
"""Wrap a parser and try to fix parsing errors."""
|
||||
|
||||
@ -23,7 +30,9 @@ class OutputFixingParser(BaseOutputParser[T]):
|
||||
parser: BaseOutputParser[T]
|
||||
"""The parser to use to parse the output."""
|
||||
# Should be an LLMChain but we want to avoid top-level imports from langchain.chains
|
||||
retry_chain: Union[RunnableSerializable, Any]
|
||||
retry_chain: Union[
|
||||
RunnableSerializable[OutputFixingParserRetryChainInput, str], Any
|
||||
]
|
||||
"""The RunnableSerializable to use to retry the completion (Legacy: LLMChain)."""
|
||||
max_retries: int = 1
|
||||
"""The maximum number of times to retry the parse."""
|
||||
@ -73,16 +82,16 @@ class OutputFixingParser(BaseOutputParser[T]):
|
||||
try:
|
||||
completion = self.retry_chain.invoke(
|
||||
dict(
|
||||
instructions=self.parser.get_format_instructions(), # noqa: E501
|
||||
input=completion,
|
||||
instructions=self.parser.get_format_instructions(),
|
||||
completion=completion,
|
||||
error=repr(e),
|
||||
)
|
||||
)
|
||||
except (NotImplementedError, AttributeError):
|
||||
# Case: self.parser does not have get_format_instructions # noqa: E501
|
||||
# Case: self.parser does not have get_format_instructions
|
||||
completion = self.retry_chain.invoke(
|
||||
dict(
|
||||
input=completion,
|
||||
completion=completion,
|
||||
error=repr(e),
|
||||
)
|
||||
)
|
||||
@ -102,7 +111,7 @@ class OutputFixingParser(BaseOutputParser[T]):
|
||||
retries += 1
|
||||
if self.legacy and hasattr(self.retry_chain, "arun"):
|
||||
completion = await self.retry_chain.arun(
|
||||
instructions=self.parser.get_format_instructions(), # noqa: E501
|
||||
instructions=self.parser.get_format_instructions(),
|
||||
completion=completion,
|
||||
error=repr(e),
|
||||
)
|
||||
@ -110,16 +119,16 @@ class OutputFixingParser(BaseOutputParser[T]):
|
||||
try:
|
||||
completion = await self.retry_chain.ainvoke(
|
||||
dict(
|
||||
instructions=self.parser.get_format_instructions(), # noqa: E501
|
||||
input=completion,
|
||||
instructions=self.parser.get_format_instructions(),
|
||||
completion=completion,
|
||||
error=repr(e),
|
||||
)
|
||||
)
|
||||
except (NotImplementedError, AttributeError):
|
||||
# Case: self.parser does not have get_format_instructions # noqa: E501
|
||||
# Case: self.parser does not have get_format_instructions
|
||||
completion = await self.retry_chain.ainvoke(
|
||||
dict(
|
||||
input=completion,
|
||||
completion=completion,
|
||||
error=repr(e),
|
||||
)
|
||||
)
|
||||
|
@ -8,6 +8,7 @@ from langchain_core.output_parsers import BaseOutputParser
|
||||
from langchain_core.prompt_values import PromptValue
|
||||
from langchain_core.prompts import BasePromptTemplate, PromptTemplate
|
||||
from langchain_core.runnables import RunnableSerializable
|
||||
from typing_extensions import TypedDict
|
||||
|
||||
NAIVE_COMPLETION_RETRY = """Prompt:
|
||||
{prompt}
|
||||
@ -34,6 +35,17 @@ NAIVE_RETRY_WITH_ERROR_PROMPT = PromptTemplate.from_template(
|
||||
T = TypeVar("T")
|
||||
|
||||
|
||||
class RetryOutputParserRetryChainInput(TypedDict):
|
||||
prompt: str
|
||||
completion: str
|
||||
|
||||
|
||||
class RetryWithErrorOutputParserRetryChainInput(TypedDict):
|
||||
prompt: str
|
||||
completion: str
|
||||
error: str
|
||||
|
||||
|
||||
class RetryOutputParser(BaseOutputParser[T]):
|
||||
"""Wrap a parser and try to fix parsing errors.
|
||||
|
||||
@ -44,7 +56,7 @@ class RetryOutputParser(BaseOutputParser[T]):
|
||||
parser: BaseOutputParser[T]
|
||||
"""The parser to use to parse the output."""
|
||||
# Should be an LLMChain but we want to avoid top-level imports from langchain.chains
|
||||
retry_chain: Union[RunnableSerializable, Any]
|
||||
retry_chain: Union[RunnableSerializable[RetryOutputParserRetryChainInput, str], Any]
|
||||
"""The RunnableSerializable to use to retry the completion (Legacy: LLMChain)."""
|
||||
max_retries: int = 1
|
||||
"""The maximum number of times to retry the parse."""
|
||||
@ -97,13 +109,12 @@ class RetryOutputParser(BaseOutputParser[T]):
|
||||
completion = self.retry_chain.run(
|
||||
prompt=prompt_value.to_string(),
|
||||
completion=completion,
|
||||
error=repr(e),
|
||||
)
|
||||
else:
|
||||
completion = self.retry_chain.invoke(
|
||||
dict(
|
||||
prompt=prompt_value.to_string(),
|
||||
input=completion,
|
||||
completion=completion,
|
||||
)
|
||||
)
|
||||
|
||||
@ -139,7 +150,7 @@ class RetryOutputParser(BaseOutputParser[T]):
|
||||
completion = await self.retry_chain.ainvoke(
|
||||
dict(
|
||||
prompt=prompt_value.to_string(),
|
||||
input=completion,
|
||||
completion=completion,
|
||||
)
|
||||
)
|
||||
|
||||
@ -174,8 +185,10 @@ class RetryWithErrorOutputParser(BaseOutputParser[T]):
|
||||
|
||||
parser: BaseOutputParser[T]
|
||||
"""The parser to use to parse the output."""
|
||||
# Should be an LLMChain but we want to avoid top-level imports from langchain.chains # noqa: E501
|
||||
retry_chain: Union[RunnableSerializable, Any]
|
||||
# Should be an LLMChain but we want to avoid top-level imports from langchain.chains
|
||||
retry_chain: Union[
|
||||
RunnableSerializable[RetryWithErrorOutputParserRetryChainInput, str], Any
|
||||
]
|
||||
"""The RunnableSerializable to use to retry the completion (Legacy: LLMChain)."""
|
||||
max_retries: int = 1
|
||||
"""The maximum number of times to retry the parse."""
|
||||
@ -204,7 +217,7 @@ class RetryWithErrorOutputParser(BaseOutputParser[T]):
|
||||
chain = prompt | llm
|
||||
return cls(parser=parser, retry_chain=chain, max_retries=max_retries)
|
||||
|
||||
def parse_with_prompt(self, completion: str, prompt_value: PromptValue) -> T: # noqa: E501
|
||||
def parse_with_prompt(self, completion: str, prompt_value: PromptValue) -> T:
|
||||
retries = 0
|
||||
|
||||
while retries <= self.max_retries:
|
||||
@ -224,7 +237,7 @@ class RetryWithErrorOutputParser(BaseOutputParser[T]):
|
||||
else:
|
||||
completion = self.retry_chain.invoke(
|
||||
dict(
|
||||
input=completion,
|
||||
completion=completion,
|
||||
prompt=prompt_value.to_string(),
|
||||
error=repr(e),
|
||||
)
|
||||
@ -253,7 +266,7 @@ class RetryWithErrorOutputParser(BaseOutputParser[T]):
|
||||
completion = await self.retry_chain.ainvoke(
|
||||
dict(
|
||||
prompt=prompt_value.to_string(),
|
||||
input=completion,
|
||||
completion=completion,
|
||||
error=repr(e),
|
||||
)
|
||||
)
|
||||
|
@ -1,12 +1,19 @@
|
||||
from typing import Any
|
||||
from datetime import datetime as dt
|
||||
from typing import Any, Callable, Dict, Optional, TypeVar
|
||||
|
||||
import pytest
|
||||
from langchain_core.exceptions import OutputParserException
|
||||
from langchain_core.runnables import RunnablePassthrough
|
||||
from langchain_core.prompts.prompt import PromptTemplate
|
||||
from langchain_core.runnables import Runnable, RunnableLambda, RunnablePassthrough
|
||||
from pytest_mock import MockerFixture
|
||||
|
||||
from langchain.output_parsers.boolean import BooleanOutputParser
|
||||
from langchain.output_parsers.datetime import DatetimeOutputParser
|
||||
from langchain.output_parsers.fix import BaseOutputParser, OutputFixingParser
|
||||
from langchain.output_parsers.prompts import NAIVE_FIX_PROMPT
|
||||
from langchain.pydantic_v1 import Extra
|
||||
|
||||
T = TypeVar("T")
|
||||
|
||||
|
||||
class SuccessfulParseAfterRetries(BaseOutputParser[str]):
|
||||
@ -22,7 +29,7 @@ class SuccessfulParseAfterRetries(BaseOutputParser[str]):
|
||||
return "parsed"
|
||||
|
||||
|
||||
class SuccessfulParseAfterRetriesWithGetFormatInstructions(SuccessfulParseAfterRetries): # noqa
|
||||
class SuccessfulParseAfterRetriesWithGetFormatInstructions(SuccessfulParseAfterRetries):
|
||||
def get_format_instructions(self) -> str:
|
||||
return "instructions"
|
||||
|
||||
@ -118,6 +125,120 @@ async def test_output_fixing_parser_aparse_fail() -> None:
|
||||
DatetimeOutputParser(),
|
||||
],
|
||||
)
|
||||
def test_output_fixing_parser_output_type(base_parser: BaseOutputParser) -> None: # noqa: E501
|
||||
parser = OutputFixingParser(parser=base_parser, retry_chain=RunnablePassthrough()) # noqa: E501
|
||||
def test_output_fixing_parser_output_type(
|
||||
base_parser: BaseOutputParser,
|
||||
) -> None:
|
||||
parser = OutputFixingParser(parser=base_parser, retry_chain=RunnablePassthrough())
|
||||
assert parser.OutputType is base_parser.OutputType
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"input,base_parser,retry_chain,expected",
|
||||
[
|
||||
(
|
||||
"2024/07/08",
|
||||
DatetimeOutputParser(),
|
||||
NAIVE_FIX_PROMPT | RunnableLambda(lambda _: "2024-07-08T00:00:00.000000Z"),
|
||||
dt(2024, 7, 8),
|
||||
),
|
||||
(
|
||||
# Case: retry_chain.InputType does not have 'instructions' key
|
||||
"2024/07/08",
|
||||
DatetimeOutputParser(),
|
||||
PromptTemplate.from_template("{completion}\n{error}")
|
||||
| RunnableLambda(lambda _: "2024-07-08T00:00:00.000000Z"),
|
||||
dt(2024, 7, 8),
|
||||
),
|
||||
],
|
||||
)
|
||||
def test_output_fixing_parser_parse_with_retry_chain(
|
||||
input: str,
|
||||
base_parser: BaseOutputParser[T],
|
||||
retry_chain: Runnable[Dict[str, Any], str],
|
||||
expected: T,
|
||||
mocker: MockerFixture,
|
||||
) -> None:
|
||||
# preparation
|
||||
# NOTE: Extra.allow is necessary in order to use spy and mock
|
||||
retry_chain.Config.extra = Extra.allow # type: ignore
|
||||
base_parser.Config.extra = Extra.allow # type: ignore
|
||||
invoke_spy = mocker.spy(retry_chain, "invoke")
|
||||
# NOTE: get_format_instructions of some parsers behave randomly
|
||||
instructions = base_parser.get_format_instructions()
|
||||
object.__setattr__(base_parser, "get_format_instructions", lambda: instructions)
|
||||
# test
|
||||
parser = OutputFixingParser(
|
||||
parser=base_parser,
|
||||
retry_chain=retry_chain,
|
||||
legacy=False,
|
||||
)
|
||||
assert parser.parse(input) == expected
|
||||
invoke_spy.assert_called_once_with(
|
||||
dict(
|
||||
instructions=base_parser.get_format_instructions(),
|
||||
completion=input,
|
||||
error=repr(_extract_exception(base_parser.parse, input)),
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"input,base_parser,retry_chain,expected",
|
||||
[
|
||||
(
|
||||
"2024/07/08",
|
||||
DatetimeOutputParser(),
|
||||
NAIVE_FIX_PROMPT | RunnableLambda(lambda _: "2024-07-08T00:00:00.000000Z"),
|
||||
dt(2024, 7, 8),
|
||||
),
|
||||
(
|
||||
# Case: retry_chain.InputType does not have 'instructions' key
|
||||
"2024/07/08",
|
||||
DatetimeOutputParser(),
|
||||
PromptTemplate.from_template("{completion}\n{error}")
|
||||
| RunnableLambda(lambda _: "2024-07-08T00:00:00.000000Z"),
|
||||
dt(2024, 7, 8),
|
||||
),
|
||||
],
|
||||
)
|
||||
async def test_output_fixing_parser_aparse_with_retry_chain(
|
||||
input: str,
|
||||
base_parser: BaseOutputParser[T],
|
||||
retry_chain: Runnable[Dict[str, Any], str],
|
||||
expected: T,
|
||||
mocker: MockerFixture,
|
||||
) -> None:
|
||||
# preparation
|
||||
# NOTE: Extra.allow is necessary in order to use spy and mock
|
||||
retry_chain.Config.extra = Extra.allow # type: ignore
|
||||
base_parser.Config.extra = Extra.allow # type: ignore
|
||||
ainvoke_spy = mocker.spy(retry_chain, "ainvoke")
|
||||
# NOTE: get_format_instructions of some parsers behave randomly
|
||||
instructions = base_parser.get_format_instructions()
|
||||
object.__setattr__(base_parser, "get_format_instructions", lambda: instructions)
|
||||
# test
|
||||
parser = OutputFixingParser(
|
||||
parser=base_parser,
|
||||
retry_chain=retry_chain,
|
||||
legacy=False,
|
||||
)
|
||||
assert (await parser.aparse(input)) == expected
|
||||
ainvoke_spy.assert_called_once_with(
|
||||
dict(
|
||||
instructions=base_parser.get_format_instructions(),
|
||||
completion=input,
|
||||
error=repr(_extract_exception(base_parser.parse, input)),
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
def _extract_exception(
|
||||
func: Callable[..., Any],
|
||||
*args: Any,
|
||||
**kwargs: Any,
|
||||
) -> Optional[Exception]:
|
||||
try:
|
||||
func(*args, **kwargs)
|
||||
except Exception as e:
|
||||
return e
|
||||
return None
|
||||
|
@ -1,24 +1,29 @@
|
||||
from typing import Any
|
||||
from datetime import datetime as dt
|
||||
from typing import Any, Callable, Dict, Optional, TypeVar
|
||||
|
||||
import pytest
|
||||
from langchain_core.prompt_values import StringPromptValue
|
||||
from langchain_core.runnables import RunnablePassthrough
|
||||
from langchain_core.prompt_values import PromptValue, StringPromptValue
|
||||
from langchain_core.runnables import Runnable, RunnableLambda, RunnablePassthrough
|
||||
from pytest_mock import MockerFixture
|
||||
|
||||
from langchain.output_parsers.boolean import BooleanOutputParser
|
||||
from langchain.output_parsers.datetime import DatetimeOutputParser
|
||||
from langchain.output_parsers.retry import (
|
||||
NAIVE_RETRY_PROMPT,
|
||||
NAIVE_RETRY_WITH_ERROR_PROMPT,
|
||||
BaseOutputParser,
|
||||
OutputParserException,
|
||||
RetryOutputParser,
|
||||
RetryWithErrorOutputParser,
|
||||
)
|
||||
from langchain.pydantic_v1 import Extra
|
||||
|
||||
T = TypeVar("T")
|
||||
|
||||
|
||||
class SuccessfulParseAfterRetries(BaseOutputParser[str]):
|
||||
parse_count: int = 0 # Number of times parse has been called
|
||||
attemp_count_before_success: (
|
||||
int # Number of times to fail before succeeding # noqa
|
||||
)
|
||||
attemp_count_before_success: int # Number of times to fail before succeeding
|
||||
error_msg: str = "error"
|
||||
|
||||
def parse(self, *args: Any, **kwargs: Any) -> str:
|
||||
@ -37,7 +42,7 @@ def test_retry_output_parser_parse_with_prompt() -> None:
|
||||
max_retries=n, # n times to retry, that is, (n+1) times call
|
||||
legacy=False,
|
||||
)
|
||||
actual = parser.parse_with_prompt("completion", StringPromptValue(text="dummy")) # noqa: E501
|
||||
actual = parser.parse_with_prompt("completion", StringPromptValue(text="dummy"))
|
||||
assert actual == "parsed"
|
||||
assert base_parser.parse_count == n + 1
|
||||
|
||||
@ -82,7 +87,7 @@ async def test_retry_output_parser_aparse_with_prompt_fail() -> None:
|
||||
legacy=False,
|
||||
)
|
||||
with pytest.raises(OutputParserException):
|
||||
await parser.aparse_with_prompt("completion", StringPromptValue(text="dummy")) # noqa: E501
|
||||
await parser.aparse_with_prompt("completion", StringPromptValue(text="dummy"))
|
||||
assert base_parser.parse_count == n
|
||||
|
||||
|
||||
@ -121,7 +126,7 @@ def test_retry_with_error_output_parser_parse_with_prompt() -> None:
|
||||
max_retries=n, # n times to retry, that is, (n+1) times call
|
||||
legacy=False,
|
||||
)
|
||||
actual = parser.parse_with_prompt("completion", StringPromptValue(text="dummy")) # noqa: E501
|
||||
actual = parser.parse_with_prompt("completion", StringPromptValue(text="dummy"))
|
||||
assert actual == "parsed"
|
||||
assert base_parser.parse_count == n + 1
|
||||
|
||||
@ -156,7 +161,7 @@ async def test_retry_with_error_output_parser_aparse_with_prompt() -> None:
|
||||
assert base_parser.parse_count == n + 1
|
||||
|
||||
|
||||
async def test_retry_with_error_output_parser_aparse_with_prompt_fail() -> None: # noqa: E501
|
||||
async def test_retry_with_error_output_parser_aparse_with_prompt_fail() -> None:
|
||||
n: int = 5 # Success on the (n+1)-th attempt
|
||||
base_parser = SuccessfulParseAfterRetries(attemp_count_before_success=n)
|
||||
parser = RetryWithErrorOutputParser(
|
||||
@ -166,7 +171,7 @@ async def test_retry_with_error_output_parser_aparse_with_prompt_fail() -> None:
|
||||
legacy=False,
|
||||
)
|
||||
with pytest.raises(OutputParserException):
|
||||
await parser.aparse_with_prompt("completion", StringPromptValue(text="dummy")) # noqa: E501
|
||||
await parser.aparse_with_prompt("completion", StringPromptValue(text="dummy"))
|
||||
assert base_parser.parse_count == n
|
||||
|
||||
|
||||
@ -196,3 +201,177 @@ def test_retry_with_error_output_parser_parse_is_not_implemented() -> None:
|
||||
)
|
||||
with pytest.raises(NotImplementedError):
|
||||
parser.parse("completion")
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"input,prompt,base_parser,retry_chain,expected",
|
||||
[
|
||||
(
|
||||
"2024/07/08",
|
||||
StringPromptValue(text="dummy"),
|
||||
DatetimeOutputParser(),
|
||||
NAIVE_RETRY_PROMPT
|
||||
| RunnableLambda(lambda _: "2024-07-08T00:00:00.000000Z"),
|
||||
dt(2024, 7, 8),
|
||||
)
|
||||
],
|
||||
)
|
||||
def test_retry_output_parser_parse_with_prompt_with_retry_chain(
|
||||
input: str,
|
||||
prompt: PromptValue,
|
||||
base_parser: BaseOutputParser[T],
|
||||
retry_chain: Runnable[Dict[str, Any], str],
|
||||
expected: T,
|
||||
mocker: MockerFixture,
|
||||
) -> None:
|
||||
# preparation
|
||||
# NOTE: Extra.allow is necessary in order to use spy and mock
|
||||
retry_chain.Config.extra = Extra.allow # type: ignore
|
||||
invoke_spy = mocker.spy(retry_chain, "invoke")
|
||||
# test
|
||||
parser = RetryOutputParser(
|
||||
parser=base_parser,
|
||||
retry_chain=retry_chain,
|
||||
legacy=False,
|
||||
)
|
||||
assert parser.parse_with_prompt(input, prompt) == expected
|
||||
invoke_spy.assert_called_once_with(
|
||||
dict(
|
||||
prompt=prompt.to_string(),
|
||||
completion=input,
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"input,prompt,base_parser,retry_chain,expected",
|
||||
[
|
||||
(
|
||||
"2024/07/08",
|
||||
StringPromptValue(text="dummy"),
|
||||
DatetimeOutputParser(),
|
||||
NAIVE_RETRY_PROMPT
|
||||
| RunnableLambda(lambda _: "2024-07-08T00:00:00.000000Z"),
|
||||
dt(2024, 7, 8),
|
||||
)
|
||||
],
|
||||
)
|
||||
async def test_retry_output_parser_aparse_with_prompt_with_retry_chain(
|
||||
input: str,
|
||||
prompt: PromptValue,
|
||||
base_parser: BaseOutputParser[T],
|
||||
retry_chain: Runnable[Dict[str, Any], str],
|
||||
expected: T,
|
||||
mocker: MockerFixture,
|
||||
) -> None:
|
||||
# preparation
|
||||
# NOTE: Extra.allow is necessary in order to use spy and mock
|
||||
retry_chain.Config.extra = Extra.allow # type: ignore
|
||||
ainvoke_spy = mocker.spy(retry_chain, "ainvoke")
|
||||
# test
|
||||
parser = RetryOutputParser(
|
||||
parser=base_parser,
|
||||
retry_chain=retry_chain,
|
||||
legacy=False,
|
||||
)
|
||||
assert (await parser.aparse_with_prompt(input, prompt)) == expected
|
||||
ainvoke_spy.assert_called_once_with(
|
||||
dict(
|
||||
prompt=prompt.to_string(),
|
||||
completion=input,
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"input,prompt,base_parser,retry_chain,expected",
|
||||
[
|
||||
(
|
||||
"2024/07/08",
|
||||
StringPromptValue(text="dummy"),
|
||||
DatetimeOutputParser(),
|
||||
NAIVE_RETRY_WITH_ERROR_PROMPT
|
||||
| RunnableLambda(lambda _: "2024-07-08T00:00:00.000000Z"),
|
||||
dt(2024, 7, 8),
|
||||
)
|
||||
],
|
||||
)
|
||||
def test_retry_with_error_output_parser_parse_with_prompt_with_retry_chain(
|
||||
input: str,
|
||||
prompt: PromptValue,
|
||||
base_parser: BaseOutputParser[T],
|
||||
retry_chain: Runnable[Dict[str, Any], str],
|
||||
expected: T,
|
||||
mocker: MockerFixture,
|
||||
) -> None:
|
||||
# preparation
|
||||
# NOTE: Extra.allow is necessary in order to use spy and mock
|
||||
retry_chain.Config.extra = Extra.allow # type: ignore
|
||||
invoke_spy = mocker.spy(retry_chain, "invoke")
|
||||
# test
|
||||
parser = RetryWithErrorOutputParser(
|
||||
parser=base_parser,
|
||||
retry_chain=retry_chain,
|
||||
legacy=False,
|
||||
)
|
||||
assert parser.parse_with_prompt(input, prompt) == expected
|
||||
invoke_spy.assert_called_once_with(
|
||||
dict(
|
||||
prompt=prompt.to_string(),
|
||||
completion=input,
|
||||
error=repr(_extract_exception(base_parser.parse, input)),
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"input,prompt,base_parser,retry_chain,expected",
|
||||
[
|
||||
(
|
||||
"2024/07/08",
|
||||
StringPromptValue(text="dummy"),
|
||||
DatetimeOutputParser(),
|
||||
NAIVE_RETRY_WITH_ERROR_PROMPT
|
||||
| RunnableLambda(lambda _: "2024-07-08T00:00:00.000000Z"),
|
||||
dt(2024, 7, 8),
|
||||
)
|
||||
],
|
||||
)
|
||||
async def test_retry_with_error_output_parser_aparse_with_prompt_with_retry_chain(
|
||||
input: str,
|
||||
prompt: PromptValue,
|
||||
base_parser: BaseOutputParser[T],
|
||||
retry_chain: Runnable[Dict[str, Any], str],
|
||||
expected: T,
|
||||
mocker: MockerFixture,
|
||||
) -> None:
|
||||
# preparation
|
||||
# NOTE: Extra.allow is necessary in order to use spy and mock
|
||||
retry_chain.Config.extra = Extra.allow # type: ignore
|
||||
ainvoke_spy = mocker.spy(retry_chain, "ainvoke")
|
||||
# test
|
||||
parser = RetryWithErrorOutputParser(
|
||||
parser=base_parser,
|
||||
retry_chain=retry_chain,
|
||||
legacy=False,
|
||||
)
|
||||
assert (await parser.aparse_with_prompt(input, prompt)) == expected
|
||||
ainvoke_spy.assert_called_once_with(
|
||||
dict(
|
||||
prompt=prompt.to_string(),
|
||||
completion=input,
|
||||
error=repr(_extract_exception(base_parser.parse, input)),
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
def _extract_exception(
|
||||
func: Callable[..., Any],
|
||||
*args: Any,
|
||||
**kwargs: Any,
|
||||
) -> Optional[Exception]:
|
||||
try:
|
||||
func(*args, **kwargs)
|
||||
except Exception as e:
|
||||
return e
|
||||
return None
|
||||
|
Loading…
Reference in New Issue
Block a user