mirror of
https://github.com/hwchase17/langchain.git
synced 2025-06-23 15:19:33 +00:00
langchain[patch]: fix wrong dict
key in OutputFixingParser
, RetryOutputParser
and RetryWithErrorOutputParser
(#23967)
# Description This PR aims to solve a bug in `OutputFixingParser`, `RetryOutputParser` and `RetryWithErrorOutputParser` The bug is that the wrong keyword argument was given to `retry_chain`. The correct keyword argument is 'completion', but 'input' is used. This pull request makes the following changes: 1. correct a `dict` key given to `retry_chain`; 2. add a test when using the default prompt. - `NAIVE_FIX_PROMPT` for `OutputFixingParser`; - `NAIVE_RETRY_PROMPT` for `RetryOutputParser`; - `NAIVE_RETRY_WITH_ERROR_PROMPT` for `RetryWithErrorOutputParser`; 3. ~~add comments on `retry_chain` input and output types~~ clarify `InputType` and `OutputType` of `retry_chain` # Issue The bug is pointed out in https://github.com/langchain-ai/langchain/pull/19792#issuecomment-2196512928 --------- Co-authored-by: Erick Friis <erick@langchain.dev>
This commit is contained in:
parent
a47f69a120
commit
a402de3dae
@ -7,12 +7,19 @@ from langchain_core.language_models import BaseLanguageModel
|
|||||||
from langchain_core.output_parsers import BaseOutputParser
|
from langchain_core.output_parsers import BaseOutputParser
|
||||||
from langchain_core.prompts import BasePromptTemplate
|
from langchain_core.prompts import BasePromptTemplate
|
||||||
from langchain_core.runnables import RunnableSerializable
|
from langchain_core.runnables import RunnableSerializable
|
||||||
|
from typing_extensions import TypedDict
|
||||||
|
|
||||||
from langchain.output_parsers.prompts import NAIVE_FIX_PROMPT
|
from langchain.output_parsers.prompts import NAIVE_FIX_PROMPT
|
||||||
|
|
||||||
T = TypeVar("T")
|
T = TypeVar("T")
|
||||||
|
|
||||||
|
|
||||||
|
class OutputFixingParserRetryChainInput(TypedDict, total=False):
|
||||||
|
instructions: str
|
||||||
|
completion: str
|
||||||
|
error: str
|
||||||
|
|
||||||
|
|
||||||
class OutputFixingParser(BaseOutputParser[T]):
|
class OutputFixingParser(BaseOutputParser[T]):
|
||||||
"""Wrap a parser and try to fix parsing errors."""
|
"""Wrap a parser and try to fix parsing errors."""
|
||||||
|
|
||||||
@ -23,7 +30,9 @@ class OutputFixingParser(BaseOutputParser[T]):
|
|||||||
parser: BaseOutputParser[T]
|
parser: BaseOutputParser[T]
|
||||||
"""The parser to use to parse the output."""
|
"""The parser to use to parse the output."""
|
||||||
# Should be an LLMChain but we want to avoid top-level imports from langchain.chains
|
# Should be an LLMChain but we want to avoid top-level imports from langchain.chains
|
||||||
retry_chain: Union[RunnableSerializable, Any]
|
retry_chain: Union[
|
||||||
|
RunnableSerializable[OutputFixingParserRetryChainInput, str], Any
|
||||||
|
]
|
||||||
"""The RunnableSerializable to use to retry the completion (Legacy: LLMChain)."""
|
"""The RunnableSerializable to use to retry the completion (Legacy: LLMChain)."""
|
||||||
max_retries: int = 1
|
max_retries: int = 1
|
||||||
"""The maximum number of times to retry the parse."""
|
"""The maximum number of times to retry the parse."""
|
||||||
@ -73,16 +82,16 @@ class OutputFixingParser(BaseOutputParser[T]):
|
|||||||
try:
|
try:
|
||||||
completion = self.retry_chain.invoke(
|
completion = self.retry_chain.invoke(
|
||||||
dict(
|
dict(
|
||||||
instructions=self.parser.get_format_instructions(), # noqa: E501
|
instructions=self.parser.get_format_instructions(),
|
||||||
input=completion,
|
completion=completion,
|
||||||
error=repr(e),
|
error=repr(e),
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
except (NotImplementedError, AttributeError):
|
except (NotImplementedError, AttributeError):
|
||||||
# Case: self.parser does not have get_format_instructions # noqa: E501
|
# Case: self.parser does not have get_format_instructions
|
||||||
completion = self.retry_chain.invoke(
|
completion = self.retry_chain.invoke(
|
||||||
dict(
|
dict(
|
||||||
input=completion,
|
completion=completion,
|
||||||
error=repr(e),
|
error=repr(e),
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
@ -102,7 +111,7 @@ class OutputFixingParser(BaseOutputParser[T]):
|
|||||||
retries += 1
|
retries += 1
|
||||||
if self.legacy and hasattr(self.retry_chain, "arun"):
|
if self.legacy and hasattr(self.retry_chain, "arun"):
|
||||||
completion = await self.retry_chain.arun(
|
completion = await self.retry_chain.arun(
|
||||||
instructions=self.parser.get_format_instructions(), # noqa: E501
|
instructions=self.parser.get_format_instructions(),
|
||||||
completion=completion,
|
completion=completion,
|
||||||
error=repr(e),
|
error=repr(e),
|
||||||
)
|
)
|
||||||
@ -110,16 +119,16 @@ class OutputFixingParser(BaseOutputParser[T]):
|
|||||||
try:
|
try:
|
||||||
completion = await self.retry_chain.ainvoke(
|
completion = await self.retry_chain.ainvoke(
|
||||||
dict(
|
dict(
|
||||||
instructions=self.parser.get_format_instructions(), # noqa: E501
|
instructions=self.parser.get_format_instructions(),
|
||||||
input=completion,
|
completion=completion,
|
||||||
error=repr(e),
|
error=repr(e),
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
except (NotImplementedError, AttributeError):
|
except (NotImplementedError, AttributeError):
|
||||||
# Case: self.parser does not have get_format_instructions # noqa: E501
|
# Case: self.parser does not have get_format_instructions
|
||||||
completion = await self.retry_chain.ainvoke(
|
completion = await self.retry_chain.ainvoke(
|
||||||
dict(
|
dict(
|
||||||
input=completion,
|
completion=completion,
|
||||||
error=repr(e),
|
error=repr(e),
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
@ -8,6 +8,7 @@ from langchain_core.output_parsers import BaseOutputParser
|
|||||||
from langchain_core.prompt_values import PromptValue
|
from langchain_core.prompt_values import PromptValue
|
||||||
from langchain_core.prompts import BasePromptTemplate, PromptTemplate
|
from langchain_core.prompts import BasePromptTemplate, PromptTemplate
|
||||||
from langchain_core.runnables import RunnableSerializable
|
from langchain_core.runnables import RunnableSerializable
|
||||||
|
from typing_extensions import TypedDict
|
||||||
|
|
||||||
NAIVE_COMPLETION_RETRY = """Prompt:
|
NAIVE_COMPLETION_RETRY = """Prompt:
|
||||||
{prompt}
|
{prompt}
|
||||||
@ -34,6 +35,17 @@ NAIVE_RETRY_WITH_ERROR_PROMPT = PromptTemplate.from_template(
|
|||||||
T = TypeVar("T")
|
T = TypeVar("T")
|
||||||
|
|
||||||
|
|
||||||
|
class RetryOutputParserRetryChainInput(TypedDict):
|
||||||
|
prompt: str
|
||||||
|
completion: str
|
||||||
|
|
||||||
|
|
||||||
|
class RetryWithErrorOutputParserRetryChainInput(TypedDict):
|
||||||
|
prompt: str
|
||||||
|
completion: str
|
||||||
|
error: str
|
||||||
|
|
||||||
|
|
||||||
class RetryOutputParser(BaseOutputParser[T]):
|
class RetryOutputParser(BaseOutputParser[T]):
|
||||||
"""Wrap a parser and try to fix parsing errors.
|
"""Wrap a parser and try to fix parsing errors.
|
||||||
|
|
||||||
@ -44,7 +56,7 @@ class RetryOutputParser(BaseOutputParser[T]):
|
|||||||
parser: BaseOutputParser[T]
|
parser: BaseOutputParser[T]
|
||||||
"""The parser to use to parse the output."""
|
"""The parser to use to parse the output."""
|
||||||
# Should be an LLMChain but we want to avoid top-level imports from langchain.chains
|
# Should be an LLMChain but we want to avoid top-level imports from langchain.chains
|
||||||
retry_chain: Union[RunnableSerializable, Any]
|
retry_chain: Union[RunnableSerializable[RetryOutputParserRetryChainInput, str], Any]
|
||||||
"""The RunnableSerializable to use to retry the completion (Legacy: LLMChain)."""
|
"""The RunnableSerializable to use to retry the completion (Legacy: LLMChain)."""
|
||||||
max_retries: int = 1
|
max_retries: int = 1
|
||||||
"""The maximum number of times to retry the parse."""
|
"""The maximum number of times to retry the parse."""
|
||||||
@ -97,13 +109,12 @@ class RetryOutputParser(BaseOutputParser[T]):
|
|||||||
completion = self.retry_chain.run(
|
completion = self.retry_chain.run(
|
||||||
prompt=prompt_value.to_string(),
|
prompt=prompt_value.to_string(),
|
||||||
completion=completion,
|
completion=completion,
|
||||||
error=repr(e),
|
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
completion = self.retry_chain.invoke(
|
completion = self.retry_chain.invoke(
|
||||||
dict(
|
dict(
|
||||||
prompt=prompt_value.to_string(),
|
prompt=prompt_value.to_string(),
|
||||||
input=completion,
|
completion=completion,
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -139,7 +150,7 @@ class RetryOutputParser(BaseOutputParser[T]):
|
|||||||
completion = await self.retry_chain.ainvoke(
|
completion = await self.retry_chain.ainvoke(
|
||||||
dict(
|
dict(
|
||||||
prompt=prompt_value.to_string(),
|
prompt=prompt_value.to_string(),
|
||||||
input=completion,
|
completion=completion,
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -174,8 +185,10 @@ class RetryWithErrorOutputParser(BaseOutputParser[T]):
|
|||||||
|
|
||||||
parser: BaseOutputParser[T]
|
parser: BaseOutputParser[T]
|
||||||
"""The parser to use to parse the output."""
|
"""The parser to use to parse the output."""
|
||||||
# Should be an LLMChain but we want to avoid top-level imports from langchain.chains # noqa: E501
|
# Should be an LLMChain but we want to avoid top-level imports from langchain.chains
|
||||||
retry_chain: Union[RunnableSerializable, Any]
|
retry_chain: Union[
|
||||||
|
RunnableSerializable[RetryWithErrorOutputParserRetryChainInput, str], Any
|
||||||
|
]
|
||||||
"""The RunnableSerializable to use to retry the completion (Legacy: LLMChain)."""
|
"""The RunnableSerializable to use to retry the completion (Legacy: LLMChain)."""
|
||||||
max_retries: int = 1
|
max_retries: int = 1
|
||||||
"""The maximum number of times to retry the parse."""
|
"""The maximum number of times to retry the parse."""
|
||||||
@ -204,7 +217,7 @@ class RetryWithErrorOutputParser(BaseOutputParser[T]):
|
|||||||
chain = prompt | llm
|
chain = prompt | llm
|
||||||
return cls(parser=parser, retry_chain=chain, max_retries=max_retries)
|
return cls(parser=parser, retry_chain=chain, max_retries=max_retries)
|
||||||
|
|
||||||
def parse_with_prompt(self, completion: str, prompt_value: PromptValue) -> T: # noqa: E501
|
def parse_with_prompt(self, completion: str, prompt_value: PromptValue) -> T:
|
||||||
retries = 0
|
retries = 0
|
||||||
|
|
||||||
while retries <= self.max_retries:
|
while retries <= self.max_retries:
|
||||||
@ -224,7 +237,7 @@ class RetryWithErrorOutputParser(BaseOutputParser[T]):
|
|||||||
else:
|
else:
|
||||||
completion = self.retry_chain.invoke(
|
completion = self.retry_chain.invoke(
|
||||||
dict(
|
dict(
|
||||||
input=completion,
|
completion=completion,
|
||||||
prompt=prompt_value.to_string(),
|
prompt=prompt_value.to_string(),
|
||||||
error=repr(e),
|
error=repr(e),
|
||||||
)
|
)
|
||||||
@ -253,7 +266,7 @@ class RetryWithErrorOutputParser(BaseOutputParser[T]):
|
|||||||
completion = await self.retry_chain.ainvoke(
|
completion = await self.retry_chain.ainvoke(
|
||||||
dict(
|
dict(
|
||||||
prompt=prompt_value.to_string(),
|
prompt=prompt_value.to_string(),
|
||||||
input=completion,
|
completion=completion,
|
||||||
error=repr(e),
|
error=repr(e),
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
@ -1,12 +1,19 @@
|
|||||||
from typing import Any
|
from datetime import datetime as dt
|
||||||
|
from typing import Any, Callable, Dict, Optional, TypeVar
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
from langchain_core.exceptions import OutputParserException
|
from langchain_core.exceptions import OutputParserException
|
||||||
from langchain_core.runnables import RunnablePassthrough
|
from langchain_core.prompts.prompt import PromptTemplate
|
||||||
|
from langchain_core.runnables import Runnable, RunnableLambda, RunnablePassthrough
|
||||||
|
from pytest_mock import MockerFixture
|
||||||
|
|
||||||
from langchain.output_parsers.boolean import BooleanOutputParser
|
from langchain.output_parsers.boolean import BooleanOutputParser
|
||||||
from langchain.output_parsers.datetime import DatetimeOutputParser
|
from langchain.output_parsers.datetime import DatetimeOutputParser
|
||||||
from langchain.output_parsers.fix import BaseOutputParser, OutputFixingParser
|
from langchain.output_parsers.fix import BaseOutputParser, OutputFixingParser
|
||||||
|
from langchain.output_parsers.prompts import NAIVE_FIX_PROMPT
|
||||||
|
from langchain.pydantic_v1 import Extra
|
||||||
|
|
||||||
|
T = TypeVar("T")
|
||||||
|
|
||||||
|
|
||||||
class SuccessfulParseAfterRetries(BaseOutputParser[str]):
|
class SuccessfulParseAfterRetries(BaseOutputParser[str]):
|
||||||
@ -22,7 +29,7 @@ class SuccessfulParseAfterRetries(BaseOutputParser[str]):
|
|||||||
return "parsed"
|
return "parsed"
|
||||||
|
|
||||||
|
|
||||||
class SuccessfulParseAfterRetriesWithGetFormatInstructions(SuccessfulParseAfterRetries): # noqa
|
class SuccessfulParseAfterRetriesWithGetFormatInstructions(SuccessfulParseAfterRetries):
|
||||||
def get_format_instructions(self) -> str:
|
def get_format_instructions(self) -> str:
|
||||||
return "instructions"
|
return "instructions"
|
||||||
|
|
||||||
@ -118,6 +125,120 @@ async def test_output_fixing_parser_aparse_fail() -> None:
|
|||||||
DatetimeOutputParser(),
|
DatetimeOutputParser(),
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
def test_output_fixing_parser_output_type(base_parser: BaseOutputParser) -> None: # noqa: E501
|
def test_output_fixing_parser_output_type(
|
||||||
parser = OutputFixingParser(parser=base_parser, retry_chain=RunnablePassthrough()) # noqa: E501
|
base_parser: BaseOutputParser,
|
||||||
|
) -> None:
|
||||||
|
parser = OutputFixingParser(parser=base_parser, retry_chain=RunnablePassthrough())
|
||||||
assert parser.OutputType is base_parser.OutputType
|
assert parser.OutputType is base_parser.OutputType
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
"input,base_parser,retry_chain,expected",
|
||||||
|
[
|
||||||
|
(
|
||||||
|
"2024/07/08",
|
||||||
|
DatetimeOutputParser(),
|
||||||
|
NAIVE_FIX_PROMPT | RunnableLambda(lambda _: "2024-07-08T00:00:00.000000Z"),
|
||||||
|
dt(2024, 7, 8),
|
||||||
|
),
|
||||||
|
(
|
||||||
|
# Case: retry_chain.InputType does not have 'instructions' key
|
||||||
|
"2024/07/08",
|
||||||
|
DatetimeOutputParser(),
|
||||||
|
PromptTemplate.from_template("{completion}\n{error}")
|
||||||
|
| RunnableLambda(lambda _: "2024-07-08T00:00:00.000000Z"),
|
||||||
|
dt(2024, 7, 8),
|
||||||
|
),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
def test_output_fixing_parser_parse_with_retry_chain(
|
||||||
|
input: str,
|
||||||
|
base_parser: BaseOutputParser[T],
|
||||||
|
retry_chain: Runnable[Dict[str, Any], str],
|
||||||
|
expected: T,
|
||||||
|
mocker: MockerFixture,
|
||||||
|
) -> None:
|
||||||
|
# preparation
|
||||||
|
# NOTE: Extra.allow is necessary in order to use spy and mock
|
||||||
|
retry_chain.Config.extra = Extra.allow # type: ignore
|
||||||
|
base_parser.Config.extra = Extra.allow # type: ignore
|
||||||
|
invoke_spy = mocker.spy(retry_chain, "invoke")
|
||||||
|
# NOTE: get_format_instructions of some parsers behave randomly
|
||||||
|
instructions = base_parser.get_format_instructions()
|
||||||
|
object.__setattr__(base_parser, "get_format_instructions", lambda: instructions)
|
||||||
|
# test
|
||||||
|
parser = OutputFixingParser(
|
||||||
|
parser=base_parser,
|
||||||
|
retry_chain=retry_chain,
|
||||||
|
legacy=False,
|
||||||
|
)
|
||||||
|
assert parser.parse(input) == expected
|
||||||
|
invoke_spy.assert_called_once_with(
|
||||||
|
dict(
|
||||||
|
instructions=base_parser.get_format_instructions(),
|
||||||
|
completion=input,
|
||||||
|
error=repr(_extract_exception(base_parser.parse, input)),
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
"input,base_parser,retry_chain,expected",
|
||||||
|
[
|
||||||
|
(
|
||||||
|
"2024/07/08",
|
||||||
|
DatetimeOutputParser(),
|
||||||
|
NAIVE_FIX_PROMPT | RunnableLambda(lambda _: "2024-07-08T00:00:00.000000Z"),
|
||||||
|
dt(2024, 7, 8),
|
||||||
|
),
|
||||||
|
(
|
||||||
|
# Case: retry_chain.InputType does not have 'instructions' key
|
||||||
|
"2024/07/08",
|
||||||
|
DatetimeOutputParser(),
|
||||||
|
PromptTemplate.from_template("{completion}\n{error}")
|
||||||
|
| RunnableLambda(lambda _: "2024-07-08T00:00:00.000000Z"),
|
||||||
|
dt(2024, 7, 8),
|
||||||
|
),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
async def test_output_fixing_parser_aparse_with_retry_chain(
|
||||||
|
input: str,
|
||||||
|
base_parser: BaseOutputParser[T],
|
||||||
|
retry_chain: Runnable[Dict[str, Any], str],
|
||||||
|
expected: T,
|
||||||
|
mocker: MockerFixture,
|
||||||
|
) -> None:
|
||||||
|
# preparation
|
||||||
|
# NOTE: Extra.allow is necessary in order to use spy and mock
|
||||||
|
retry_chain.Config.extra = Extra.allow # type: ignore
|
||||||
|
base_parser.Config.extra = Extra.allow # type: ignore
|
||||||
|
ainvoke_spy = mocker.spy(retry_chain, "ainvoke")
|
||||||
|
# NOTE: get_format_instructions of some parsers behave randomly
|
||||||
|
instructions = base_parser.get_format_instructions()
|
||||||
|
object.__setattr__(base_parser, "get_format_instructions", lambda: instructions)
|
||||||
|
# test
|
||||||
|
parser = OutputFixingParser(
|
||||||
|
parser=base_parser,
|
||||||
|
retry_chain=retry_chain,
|
||||||
|
legacy=False,
|
||||||
|
)
|
||||||
|
assert (await parser.aparse(input)) == expected
|
||||||
|
ainvoke_spy.assert_called_once_with(
|
||||||
|
dict(
|
||||||
|
instructions=base_parser.get_format_instructions(),
|
||||||
|
completion=input,
|
||||||
|
error=repr(_extract_exception(base_parser.parse, input)),
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _extract_exception(
|
||||||
|
func: Callable[..., Any],
|
||||||
|
*args: Any,
|
||||||
|
**kwargs: Any,
|
||||||
|
) -> Optional[Exception]:
|
||||||
|
try:
|
||||||
|
func(*args, **kwargs)
|
||||||
|
except Exception as e:
|
||||||
|
return e
|
||||||
|
return None
|
||||||
|
@ -1,24 +1,29 @@
|
|||||||
from typing import Any
|
from datetime import datetime as dt
|
||||||
|
from typing import Any, Callable, Dict, Optional, TypeVar
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
from langchain_core.prompt_values import StringPromptValue
|
from langchain_core.prompt_values import PromptValue, StringPromptValue
|
||||||
from langchain_core.runnables import RunnablePassthrough
|
from langchain_core.runnables import Runnable, RunnableLambda, RunnablePassthrough
|
||||||
|
from pytest_mock import MockerFixture
|
||||||
|
|
||||||
from langchain.output_parsers.boolean import BooleanOutputParser
|
from langchain.output_parsers.boolean import BooleanOutputParser
|
||||||
from langchain.output_parsers.datetime import DatetimeOutputParser
|
from langchain.output_parsers.datetime import DatetimeOutputParser
|
||||||
from langchain.output_parsers.retry import (
|
from langchain.output_parsers.retry import (
|
||||||
|
NAIVE_RETRY_PROMPT,
|
||||||
|
NAIVE_RETRY_WITH_ERROR_PROMPT,
|
||||||
BaseOutputParser,
|
BaseOutputParser,
|
||||||
OutputParserException,
|
OutputParserException,
|
||||||
RetryOutputParser,
|
RetryOutputParser,
|
||||||
RetryWithErrorOutputParser,
|
RetryWithErrorOutputParser,
|
||||||
)
|
)
|
||||||
|
from langchain.pydantic_v1 import Extra
|
||||||
|
|
||||||
|
T = TypeVar("T")
|
||||||
|
|
||||||
|
|
||||||
class SuccessfulParseAfterRetries(BaseOutputParser[str]):
|
class SuccessfulParseAfterRetries(BaseOutputParser[str]):
|
||||||
parse_count: int = 0 # Number of times parse has been called
|
parse_count: int = 0 # Number of times parse has been called
|
||||||
attemp_count_before_success: (
|
attemp_count_before_success: int # Number of times to fail before succeeding
|
||||||
int # Number of times to fail before succeeding # noqa
|
|
||||||
)
|
|
||||||
error_msg: str = "error"
|
error_msg: str = "error"
|
||||||
|
|
||||||
def parse(self, *args: Any, **kwargs: Any) -> str:
|
def parse(self, *args: Any, **kwargs: Any) -> str:
|
||||||
@ -37,7 +42,7 @@ def test_retry_output_parser_parse_with_prompt() -> None:
|
|||||||
max_retries=n, # n times to retry, that is, (n+1) times call
|
max_retries=n, # n times to retry, that is, (n+1) times call
|
||||||
legacy=False,
|
legacy=False,
|
||||||
)
|
)
|
||||||
actual = parser.parse_with_prompt("completion", StringPromptValue(text="dummy")) # noqa: E501
|
actual = parser.parse_with_prompt("completion", StringPromptValue(text="dummy"))
|
||||||
assert actual == "parsed"
|
assert actual == "parsed"
|
||||||
assert base_parser.parse_count == n + 1
|
assert base_parser.parse_count == n + 1
|
||||||
|
|
||||||
@ -82,7 +87,7 @@ async def test_retry_output_parser_aparse_with_prompt_fail() -> None:
|
|||||||
legacy=False,
|
legacy=False,
|
||||||
)
|
)
|
||||||
with pytest.raises(OutputParserException):
|
with pytest.raises(OutputParserException):
|
||||||
await parser.aparse_with_prompt("completion", StringPromptValue(text="dummy")) # noqa: E501
|
await parser.aparse_with_prompt("completion", StringPromptValue(text="dummy"))
|
||||||
assert base_parser.parse_count == n
|
assert base_parser.parse_count == n
|
||||||
|
|
||||||
|
|
||||||
@ -121,7 +126,7 @@ def test_retry_with_error_output_parser_parse_with_prompt() -> None:
|
|||||||
max_retries=n, # n times to retry, that is, (n+1) times call
|
max_retries=n, # n times to retry, that is, (n+1) times call
|
||||||
legacy=False,
|
legacy=False,
|
||||||
)
|
)
|
||||||
actual = parser.parse_with_prompt("completion", StringPromptValue(text="dummy")) # noqa: E501
|
actual = parser.parse_with_prompt("completion", StringPromptValue(text="dummy"))
|
||||||
assert actual == "parsed"
|
assert actual == "parsed"
|
||||||
assert base_parser.parse_count == n + 1
|
assert base_parser.parse_count == n + 1
|
||||||
|
|
||||||
@ -156,7 +161,7 @@ async def test_retry_with_error_output_parser_aparse_with_prompt() -> None:
|
|||||||
assert base_parser.parse_count == n + 1
|
assert base_parser.parse_count == n + 1
|
||||||
|
|
||||||
|
|
||||||
async def test_retry_with_error_output_parser_aparse_with_prompt_fail() -> None: # noqa: E501
|
async def test_retry_with_error_output_parser_aparse_with_prompt_fail() -> None:
|
||||||
n: int = 5 # Success on the (n+1)-th attempt
|
n: int = 5 # Success on the (n+1)-th attempt
|
||||||
base_parser = SuccessfulParseAfterRetries(attemp_count_before_success=n)
|
base_parser = SuccessfulParseAfterRetries(attemp_count_before_success=n)
|
||||||
parser = RetryWithErrorOutputParser(
|
parser = RetryWithErrorOutputParser(
|
||||||
@ -166,7 +171,7 @@ async def test_retry_with_error_output_parser_aparse_with_prompt_fail() -> None:
|
|||||||
legacy=False,
|
legacy=False,
|
||||||
)
|
)
|
||||||
with pytest.raises(OutputParserException):
|
with pytest.raises(OutputParserException):
|
||||||
await parser.aparse_with_prompt("completion", StringPromptValue(text="dummy")) # noqa: E501
|
await parser.aparse_with_prompt("completion", StringPromptValue(text="dummy"))
|
||||||
assert base_parser.parse_count == n
|
assert base_parser.parse_count == n
|
||||||
|
|
||||||
|
|
||||||
@ -196,3 +201,177 @@ def test_retry_with_error_output_parser_parse_is_not_implemented() -> None:
|
|||||||
)
|
)
|
||||||
with pytest.raises(NotImplementedError):
|
with pytest.raises(NotImplementedError):
|
||||||
parser.parse("completion")
|
parser.parse("completion")
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
"input,prompt,base_parser,retry_chain,expected",
|
||||||
|
[
|
||||||
|
(
|
||||||
|
"2024/07/08",
|
||||||
|
StringPromptValue(text="dummy"),
|
||||||
|
DatetimeOutputParser(),
|
||||||
|
NAIVE_RETRY_PROMPT
|
||||||
|
| RunnableLambda(lambda _: "2024-07-08T00:00:00.000000Z"),
|
||||||
|
dt(2024, 7, 8),
|
||||||
|
)
|
||||||
|
],
|
||||||
|
)
|
||||||
|
def test_retry_output_parser_parse_with_prompt_with_retry_chain(
|
||||||
|
input: str,
|
||||||
|
prompt: PromptValue,
|
||||||
|
base_parser: BaseOutputParser[T],
|
||||||
|
retry_chain: Runnable[Dict[str, Any], str],
|
||||||
|
expected: T,
|
||||||
|
mocker: MockerFixture,
|
||||||
|
) -> None:
|
||||||
|
# preparation
|
||||||
|
# NOTE: Extra.allow is necessary in order to use spy and mock
|
||||||
|
retry_chain.Config.extra = Extra.allow # type: ignore
|
||||||
|
invoke_spy = mocker.spy(retry_chain, "invoke")
|
||||||
|
# test
|
||||||
|
parser = RetryOutputParser(
|
||||||
|
parser=base_parser,
|
||||||
|
retry_chain=retry_chain,
|
||||||
|
legacy=False,
|
||||||
|
)
|
||||||
|
assert parser.parse_with_prompt(input, prompt) == expected
|
||||||
|
invoke_spy.assert_called_once_with(
|
||||||
|
dict(
|
||||||
|
prompt=prompt.to_string(),
|
||||||
|
completion=input,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
"input,prompt,base_parser,retry_chain,expected",
|
||||||
|
[
|
||||||
|
(
|
||||||
|
"2024/07/08",
|
||||||
|
StringPromptValue(text="dummy"),
|
||||||
|
DatetimeOutputParser(),
|
||||||
|
NAIVE_RETRY_PROMPT
|
||||||
|
| RunnableLambda(lambda _: "2024-07-08T00:00:00.000000Z"),
|
||||||
|
dt(2024, 7, 8),
|
||||||
|
)
|
||||||
|
],
|
||||||
|
)
|
||||||
|
async def test_retry_output_parser_aparse_with_prompt_with_retry_chain(
|
||||||
|
input: str,
|
||||||
|
prompt: PromptValue,
|
||||||
|
base_parser: BaseOutputParser[T],
|
||||||
|
retry_chain: Runnable[Dict[str, Any], str],
|
||||||
|
expected: T,
|
||||||
|
mocker: MockerFixture,
|
||||||
|
) -> None:
|
||||||
|
# preparation
|
||||||
|
# NOTE: Extra.allow is necessary in order to use spy and mock
|
||||||
|
retry_chain.Config.extra = Extra.allow # type: ignore
|
||||||
|
ainvoke_spy = mocker.spy(retry_chain, "ainvoke")
|
||||||
|
# test
|
||||||
|
parser = RetryOutputParser(
|
||||||
|
parser=base_parser,
|
||||||
|
retry_chain=retry_chain,
|
||||||
|
legacy=False,
|
||||||
|
)
|
||||||
|
assert (await parser.aparse_with_prompt(input, prompt)) == expected
|
||||||
|
ainvoke_spy.assert_called_once_with(
|
||||||
|
dict(
|
||||||
|
prompt=prompt.to_string(),
|
||||||
|
completion=input,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
"input,prompt,base_parser,retry_chain,expected",
|
||||||
|
[
|
||||||
|
(
|
||||||
|
"2024/07/08",
|
||||||
|
StringPromptValue(text="dummy"),
|
||||||
|
DatetimeOutputParser(),
|
||||||
|
NAIVE_RETRY_WITH_ERROR_PROMPT
|
||||||
|
| RunnableLambda(lambda _: "2024-07-08T00:00:00.000000Z"),
|
||||||
|
dt(2024, 7, 8),
|
||||||
|
)
|
||||||
|
],
|
||||||
|
)
|
||||||
|
def test_retry_with_error_output_parser_parse_with_prompt_with_retry_chain(
|
||||||
|
input: str,
|
||||||
|
prompt: PromptValue,
|
||||||
|
base_parser: BaseOutputParser[T],
|
||||||
|
retry_chain: Runnable[Dict[str, Any], str],
|
||||||
|
expected: T,
|
||||||
|
mocker: MockerFixture,
|
||||||
|
) -> None:
|
||||||
|
# preparation
|
||||||
|
# NOTE: Extra.allow is necessary in order to use spy and mock
|
||||||
|
retry_chain.Config.extra = Extra.allow # type: ignore
|
||||||
|
invoke_spy = mocker.spy(retry_chain, "invoke")
|
||||||
|
# test
|
||||||
|
parser = RetryWithErrorOutputParser(
|
||||||
|
parser=base_parser,
|
||||||
|
retry_chain=retry_chain,
|
||||||
|
legacy=False,
|
||||||
|
)
|
||||||
|
assert parser.parse_with_prompt(input, prompt) == expected
|
||||||
|
invoke_spy.assert_called_once_with(
|
||||||
|
dict(
|
||||||
|
prompt=prompt.to_string(),
|
||||||
|
completion=input,
|
||||||
|
error=repr(_extract_exception(base_parser.parse, input)),
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
"input,prompt,base_parser,retry_chain,expected",
|
||||||
|
[
|
||||||
|
(
|
||||||
|
"2024/07/08",
|
||||||
|
StringPromptValue(text="dummy"),
|
||||||
|
DatetimeOutputParser(),
|
||||||
|
NAIVE_RETRY_WITH_ERROR_PROMPT
|
||||||
|
| RunnableLambda(lambda _: "2024-07-08T00:00:00.000000Z"),
|
||||||
|
dt(2024, 7, 8),
|
||||||
|
)
|
||||||
|
],
|
||||||
|
)
|
||||||
|
async def test_retry_with_error_output_parser_aparse_with_prompt_with_retry_chain(
|
||||||
|
input: str,
|
||||||
|
prompt: PromptValue,
|
||||||
|
base_parser: BaseOutputParser[T],
|
||||||
|
retry_chain: Runnable[Dict[str, Any], str],
|
||||||
|
expected: T,
|
||||||
|
mocker: MockerFixture,
|
||||||
|
) -> None:
|
||||||
|
# preparation
|
||||||
|
# NOTE: Extra.allow is necessary in order to use spy and mock
|
||||||
|
retry_chain.Config.extra = Extra.allow # type: ignore
|
||||||
|
ainvoke_spy = mocker.spy(retry_chain, "ainvoke")
|
||||||
|
# test
|
||||||
|
parser = RetryWithErrorOutputParser(
|
||||||
|
parser=base_parser,
|
||||||
|
retry_chain=retry_chain,
|
||||||
|
legacy=False,
|
||||||
|
)
|
||||||
|
assert (await parser.aparse_with_prompt(input, prompt)) == expected
|
||||||
|
ainvoke_spy.assert_called_once_with(
|
||||||
|
dict(
|
||||||
|
prompt=prompt.to_string(),
|
||||||
|
completion=input,
|
||||||
|
error=repr(_extract_exception(base_parser.parse, input)),
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _extract_exception(
|
||||||
|
func: Callable[..., Any],
|
||||||
|
*args: Any,
|
||||||
|
**kwargs: Any,
|
||||||
|
) -> Optional[Exception]:
|
||||||
|
try:
|
||||||
|
func(*args, **kwargs)
|
||||||
|
except Exception as e:
|
||||||
|
return e
|
||||||
|
return None
|
||||||
|
Loading…
Reference in New Issue
Block a user