diff --git a/libs/langchain/langchain/output_parsers/combining.py b/libs/langchain/langchain/output_parsers/combining.py
index 1222c2208db..3159a014e8d 100644
--- a/libs/langchain/langchain/output_parsers/combining.py
+++ b/libs/langchain/langchain/output_parsers/combining.py
@@ -6,7 +6,7 @@ from langchain_core.output_parsers import BaseOutputParser
 from langchain_core.pydantic_v1 import root_validator
 
 
-class CombiningOutputParser(BaseOutputParser):
+class CombiningOutputParser(BaseOutputParser[Dict[str, Any]]):
     """Combine multiple output parsers into one."""
 
     parsers: List[BaseOutputParser]
diff --git a/libs/langchain/langchain/output_parsers/enum.py b/libs/langchain/langchain/output_parsers/enum.py
index 4396a5291bf..b21d80834bf 100644
--- a/libs/langchain/langchain/output_parsers/enum.py
+++ b/libs/langchain/langchain/output_parsers/enum.py
@@ -1,12 +1,12 @@
 from enum import Enum
-from typing import Any, Dict, List, Type
+from typing import Dict, List, Type
 
 from langchain_core.exceptions import OutputParserException
 from langchain_core.output_parsers import BaseOutputParser
 from langchain_core.pydantic_v1 import root_validator
 
 
-class EnumOutputParser(BaseOutputParser):
+class EnumOutputParser(BaseOutputParser[Enum]):
     """Parse an output that is one of a set of values."""
 
     enum: Type[Enum]
@@ -23,7 +23,7 @@ class EnumOutputParser(BaseOutputParser):
     def _valid_values(self) -> List[str]:
         return [e.value for e in self.enum]
 
-    def parse(self, response: str) -> Any:
+    def parse(self, response: str) -> Enum:
         try:
             return self.enum(response.strip())
         except ValueError:
@@ -34,3 +34,7 @@ class EnumOutputParser(BaseOutputParser):
 
     def get_format_instructions(self) -> str:
         return f"Select one of the following options: {', '.join(self._valid_values)}"
+
+    @property
+    def OutputType(self) -> Type[Enum]:
+        return self.enum
diff --git a/libs/langchain/langchain/output_parsers/fix.py b/libs/langchain/langchain/output_parsers/fix.py
index 7b6e5a34426..849200105d9 100644
--- a/libs/langchain/langchain/output_parsers/fix.py
+++ b/libs/langchain/langchain/output_parsers/fix.py
@@ -1,11 +1,12 @@
 from __future__ import annotations
 
-from typing import Any, TypeVar
+from typing import Any, TypeVar, Union
 
 from langchain_core.exceptions import OutputParserException
 from langchain_core.language_models import BaseLanguageModel
 from langchain_core.output_parsers import BaseOutputParser
 from langchain_core.prompts import BasePromptTemplate
+from langchain_core.runnables import RunnableSerializable
 
 from langchain.output_parsers.prompts import NAIVE_FIX_PROMPT
 
@@ -22,10 +23,12 @@ class OutputFixingParser(BaseOutputParser[T]):
     parser: BaseOutputParser[T]
     """The parser to use to parse the output."""
     # Should be an LLMChain but we want to avoid top-level imports from langchain.chains
-    retry_chain: Any
-    """The LLMChain to use to retry the completion."""
+    retry_chain: Union[RunnableSerializable, Any]
+    """The RunnableSerializable to use to retry the completion (Legacy: LLMChain)."""
     max_retries: int = 1
     """The maximum number of times to retry the parse."""
+    legacy: bool = True
+    """Whether to use the run or arun method of the retry_chain."""
 
     @classmethod
     def from_llm(
@@ -46,9 +49,7 @@ class OutputFixingParser(BaseOutputParser[T]):
         Returns:
             OutputFixingParser
         """
-        from langchain.chains.llm import LLMChain
-
-        chain = LLMChain(llm=llm, prompt=prompt)
+        chain = prompt | llm
         return cls(parser=parser, retry_chain=chain, max_retries=max_retries)
 
     def parse(self, completion: str) -> T:
@@ -62,11 +63,29 @@ class OutputFixingParser(BaseOutputParser[T]):
                     raise e
                 else:
                     retries += 1
-                    completion = self.retry_chain.run(
-                        instructions=self.parser.get_format_instructions(),
-                        completion=completion,
-                        error=repr(e),
-                    )
+                    if self.legacy and hasattr(self.retry_chain, "run"):
+                        completion = self.retry_chain.run(
+                            instructions=self.parser.get_format_instructions(),
+                            completion=completion,
+                            error=repr(e),
+                        )
+                    else:
+                        try:
+                            completion = self.retry_chain.invoke(
+                                dict(
+                                    instructions=self.parser.get_format_instructions(),  # noqa: E501
+                                    input=completion,
+                                    error=repr(e),
+                                )
+                            )
+                        except (NotImplementedError, AttributeError):
+                            # Case: self.parser does not have get_format_instructions  # noqa: E501
+                            completion = self.retry_chain.invoke(
+                                dict(
+                                    input=completion,
+                                    error=repr(e),
+                                )
+                            )
 
         raise OutputParserException("Failed to parse")
 
@@ -81,11 +100,29 @@ class OutputFixingParser(BaseOutputParser[T]):
                     raise e
                 else:
                     retries += 1
-                    completion = await self.retry_chain.arun(
-                        instructions=self.parser.get_format_instructions(),
-                        completion=completion,
-                        error=repr(e),
-                    )
+                    if self.legacy and hasattr(self.retry_chain, "arun"):
+                        completion = await self.retry_chain.arun(
+                            instructions=self.parser.get_format_instructions(),  # noqa: E501
+                            completion=completion,
+                            error=repr(e),
+                        )
+                    else:
+                        try:
+                            completion = await self.retry_chain.ainvoke(
+                                dict(
+                                    instructions=self.parser.get_format_instructions(),  # noqa: E501
+                                    input=completion,
+                                    error=repr(e),
+                                )
+                            )
+                        except (NotImplementedError, AttributeError):
+                            # Case: self.parser does not have get_format_instructions  # noqa: E501
+                            completion = await self.retry_chain.ainvoke(
+                                dict(
+                                    input=completion,
+                                    error=repr(e),
+                                )
+                            )
 
         raise OutputParserException("Failed to parse")
 
@@ -95,3 +132,7 @@ class OutputFixingParser(BaseOutputParser[T]):
     @property
     def _type(self) -> str:
         return "output_fixing"
+
+    @property
+    def OutputType(self) -> type[T]:
+        return self.parser.OutputType
diff --git a/libs/langchain/langchain/output_parsers/pandas_dataframe.py b/libs/langchain/langchain/output_parsers/pandas_dataframe.py
index 4c0cb177d02..3447767c088 100644
--- a/libs/langchain/langchain/output_parsers/pandas_dataframe.py
+++ b/libs/langchain/langchain/output_parsers/pandas_dataframe.py
@@ -10,7 +10,7 @@ from langchain.output_parsers.format_instructions import (
 )
 
 
-class PandasDataFrameOutputParser(BaseOutputParser):
+class PandasDataFrameOutputParser(BaseOutputParser[Dict[str, Any]]):
     """Parse an output using Pandas DataFrame format."""
 
     """The Pandas DataFrame to parse."""
diff --git a/libs/langchain/langchain/output_parsers/regex.py b/libs/langchain/langchain/output_parsers/regex.py
index ea8b053e159..5add60b1b28 100644
--- a/libs/langchain/langchain/output_parsers/regex.py
+++ b/libs/langchain/langchain/output_parsers/regex.py
@@ -6,7 +6,7 @@ from typing import Dict, List, Optional
 from langchain_core.output_parsers import BaseOutputParser
 
 
-class RegexParser(BaseOutputParser):
+class RegexParser(BaseOutputParser[Dict[str, str]]):
     """Parse the output of an LLM call using a regex."""
 
     @classmethod
diff --git a/libs/langchain/langchain/output_parsers/regex_dict.py b/libs/langchain/langchain/output_parsers/regex_dict.py
index 1b390485da3..df40c7683fd 100644
--- a/libs/langchain/langchain/output_parsers/regex_dict.py
+++ b/libs/langchain/langchain/output_parsers/regex_dict.py
@@ -6,7 +6,7 @@ from typing import Dict, Optional
 from langchain_core.output_parsers import BaseOutputParser
 
 
-class RegexDictParser(BaseOutputParser):
+class RegexDictParser(BaseOutputParser[Dict[str, str]]):
     """Parse the output of an LLM call into a Dictionary using a regex."""
 
     regex_pattern: str = r"{}:\s?([^.'\n']*)\.?"  # : :meta private:
diff --git a/libs/langchain/langchain/output_parsers/retry.py b/libs/langchain/langchain/output_parsers/retry.py
index 571362f0120..b82f1796571 100644
--- a/libs/langchain/langchain/output_parsers/retry.py
+++ b/libs/langchain/langchain/output_parsers/retry.py
@@ -1,12 +1,13 @@
 from __future__ import annotations
 
-from typing import Any, TypeVar
+from typing import Any, TypeVar, Union
 
 from langchain_core.exceptions import OutputParserException
 from langchain_core.language_models import BaseLanguageModel
 from langchain_core.output_parsers import BaseOutputParser
 from langchain_core.prompt_values import PromptValue
 from langchain_core.prompts import BasePromptTemplate, PromptTemplate
+from langchain_core.runnables import RunnableSerializable
 
 NAIVE_COMPLETION_RETRY = """Prompt:
 {prompt}
@@ -43,10 +44,12 @@ class RetryOutputParser(BaseOutputParser[T]):
     parser: BaseOutputParser[T]
     """The parser to use to parse the output."""
     # Should be an LLMChain but we want to avoid top-level imports from langchain.chains
-    retry_chain: Any
-    """The LLMChain to use to retry the completion."""
+    retry_chain: Union[RunnableSerializable, Any]
+    """The RunnableSerializable to use to retry the completion (Legacy: LLMChain)."""
     max_retries: int = 1
     """The maximum number of times to retry the parse."""
+    legacy: bool = True
+    """Whether to use the run or arun method of the retry_chain."""
 
     @classmethod
     def from_llm(
@@ -67,9 +70,7 @@ class RetryOutputParser(BaseOutputParser[T]):
         Returns:
             RetryOutputParser
         """
-        from langchain.chains.llm import LLMChain
-
-        chain = LLMChain(llm=llm, prompt=prompt)
+        chain = prompt | llm
         return cls(parser=parser, retry_chain=chain, max_retries=max_retries)
 
     def parse_with_prompt(self, completion: str, prompt_value: PromptValue) -> T:
@@ -92,9 +93,19 @@ class RetryOutputParser(BaseOutputParser[T]):
                     raise e
                 else:
                     retries += 1
-                    completion = self.retry_chain.run(
-                        prompt=prompt_value.to_string(), completion=completion
-                    )
+                    if self.legacy and hasattr(self.retry_chain, "run"):
+                        completion = self.retry_chain.run(
+                            prompt=prompt_value.to_string(),
+                            completion=completion,
+                            error=repr(e),
+                        )
+                    else:
+                        completion = self.retry_chain.invoke(
+                            dict(
+                                prompt=prompt_value.to_string(),
+                                input=completion,
+                            )
+                        )
 
         raise OutputParserException("Failed to parse")
 
@@ -118,9 +129,19 @@ class RetryOutputParser(BaseOutputParser[T]):
                     raise e
                 else:
                     retries += 1
-                    completion = await self.retry_chain.arun(
-                        prompt=prompt_value.to_string(), completion=completion
-                    )
+                    if self.legacy and hasattr(self.retry_chain, "arun"):
+                        completion = await self.retry_chain.arun(
+                            prompt=prompt_value.to_string(),
+                            completion=completion,
+                            error=repr(e),
+                        )
+                    else:
+                        completion = await self.retry_chain.ainvoke(
+                            dict(
+                                prompt=prompt_value.to_string(),
+                                input=completion,
+                            )
+                        )
 
         raise OutputParserException("Failed to parse")
 
@@ -136,6 +157,10 @@ class RetryOutputParser(BaseOutputParser[T]):
     def _type(self) -> str:
         return "retry"
 
+    @property
+    def OutputType(self) -> type[T]:
+        return self.parser.OutputType
+
 
 class RetryWithErrorOutputParser(BaseOutputParser[T]):
     """Wrap a parser and try to fix parsing errors.
@@ -149,11 +174,13 @@ class RetryWithErrorOutputParser(BaseOutputParser[T]):
 
     parser: BaseOutputParser[T]
     """The parser to use to parse the output."""
-    # Should be an LLMChain but we want to avoid top-level imports from langchain.chains
-    retry_chain: Any
-    """The LLMChain to use to retry the completion."""
+    # Should be an LLMChain but we want to avoid top-level imports from langchain.chains  # noqa: E501
+    retry_chain: Union[RunnableSerializable, Any]
+    """The RunnableSerializable to use to retry the completion (Legacy: LLMChain)."""
     max_retries: int = 1
     """The maximum number of times to retry the parse."""
+    legacy: bool = True
+    """Whether to use the run or arun method of the retry_chain."""
 
     @classmethod
     def from_llm(
@@ -174,12 +201,10 @@ class RetryWithErrorOutputParser(BaseOutputParser[T]):
         Returns:
             A RetryWithErrorOutputParser.
         """
-        from langchain.chains.llm import LLMChain
-
-        chain = LLMChain(llm=llm, prompt=prompt)
+        chain = prompt | llm
         return cls(parser=parser, retry_chain=chain, max_retries=max_retries)
 
-    def parse_with_prompt(self, completion: str, prompt_value: PromptValue) -> T:
+    def parse_with_prompt(self, completion: str, prompt_value: PromptValue) -> T:  # noqa: E501
         retries = 0
 
         while retries <= self.max_retries:
@@ -190,11 +215,20 @@ class RetryWithErrorOutputParser(BaseOutputParser[T]):
                     raise e
                 else:
                     retries += 1
-                    completion = self.retry_chain.run(
-                        prompt=prompt_value.to_string(),
-                        completion=completion,
-                        error=repr(e),
-                    )
+                    if self.legacy and hasattr(self.retry_chain, "run"):
+                        completion = self.retry_chain.run(
+                            prompt=prompt_value.to_string(),
+                            completion=completion,
+                            error=repr(e),
+                        )
+                    else:
+                        completion = self.retry_chain.invoke(
+                            dict(
+                                input=completion,
+                                prompt=prompt_value.to_string(),
+                                error=repr(e),
+                            )
+                        )
 
         raise OutputParserException("Failed to parse")
 
@@ -209,11 +243,20 @@ class RetryWithErrorOutputParser(BaseOutputParser[T]):
                     raise e
                 else:
                     retries += 1
-                    completion = await self.retry_chain.arun(
-                        prompt=prompt_value.to_string(),
-                        completion=completion,
-                        error=repr(e),
-                    )
+                    if self.legacy and hasattr(self.retry_chain, "arun"):
+                        completion = await self.retry_chain.arun(
+                            prompt=prompt_value.to_string(),
+                            completion=completion,
+                            error=repr(e),
+                        )
+                    else:
+                        completion = await self.retry_chain.ainvoke(
+                            dict(
+                                prompt=prompt_value.to_string(),
+                                input=completion,
+                                error=repr(e),
+                            )
+                        )
 
         raise OutputParserException("Failed to parse")
 
@@ -228,3 +271,7 @@ class RetryWithErrorOutputParser(BaseOutputParser[T]):
     @property
     def _type(self) -> str:
         return "retry_with_error"
+
+    @property
+    def OutputType(self) -> type[T]:
+        return self.parser.OutputType
diff --git a/libs/langchain/langchain/output_parsers/structured.py b/libs/langchain/langchain/output_parsers/structured.py
index 57abb7f2c88..097e1a7170f 100644
--- a/libs/langchain/langchain/output_parsers/structured.py
+++ b/libs/langchain/langchain/output_parsers/structured.py
@@ -1,6 +1,6 @@
 from __future__ import annotations
 
-from typing import Any, List
+from typing import Any, Dict, List
 
 from langchain_core.output_parsers import BaseOutputParser
 from langchain_core.output_parsers.json import parse_and_check_json_markdown
@@ -31,7 +31,7 @@ def _get_sub_string(schema: ResponseSchema) -> str:
     )
 
 
-class StructuredOutputParser(BaseOutputParser):
+class StructuredOutputParser(BaseOutputParser[Dict[str, Any]]):
     """Parse the output of an LLM call to a structured output."""
 
     response_schemas: List[ResponseSchema]
@@ -92,7 +92,7 @@ class StructuredOutputParser(BaseOutputParser):
         else:
             return STRUCTURED_FORMAT_INSTRUCTIONS.format(format=schema_str)
 
-    def parse(self, text: str) -> Any:
+    def parse(self, text: str) -> Dict[str, Any]:
         expected_keys = [rs.name for rs in self.response_schemas]
         return parse_and_check_json_markdown(text, expected_keys)
 
diff --git a/libs/langchain/langchain/output_parsers/yaml.py b/libs/langchain/langchain/output_parsers/yaml.py
index 21bcf359a2c..e7c071eb400 100644
--- a/libs/langchain/langchain/output_parsers/yaml.py
+++ b/libs/langchain/langchain/output_parsers/yaml.py
@@ -60,3 +60,7 @@ class YamlOutputParser(BaseOutputParser[T]):
     @property
     def _type(self) -> str:
         return "yaml"
+
+    @property
+    def OutputType(self) -> Type[T]:
+        return self.pydantic_object
diff --git a/libs/langchain/tests/unit_tests/output_parsers/test_boolean_parser.py b/libs/langchain/tests/unit_tests/output_parsers/test_boolean_parser.py
index bae5992875f..2ccfc2f1b2e 100644
--- a/libs/langchain/tests/unit_tests/output_parsers/test_boolean_parser.py
+++ b/libs/langchain/tests/unit_tests/output_parsers/test_boolean_parser.py
@@ -39,3 +39,8 @@ def test_boolean_output_parser_parse() -> None:
     # Bad input
     with pytest.raises(ValueError):
         parser.parse("BOOM")
+
+
+def test_boolean_output_parser_output_type() -> None:
+    """Test the output type of the boolean output parser is a boolean."""
+    assert BooleanOutputParser().OutputType == bool
diff --git a/libs/langchain/tests/unit_tests/output_parsers/test_combining_parser.py b/libs/langchain/tests/unit_tests/output_parsers/test_combining_parser.py
index 21a3ab6a92a..a2d757e8863 100644
--- a/libs/langchain/tests/unit_tests/output_parsers/test_combining_parser.py
+++ b/libs/langchain/tests/unit_tests/output_parsers/test_combining_parser.py
@@ -1,4 +1,6 @@
 """Test in memory docstore."""
+from typing import Any, Dict
+
 from langchain.output_parsers.combining import CombiningOutputParser
 from langchain.output_parsers.regex import RegexParser
 from langchain.output_parsers.structured import ResponseSchema, StructuredOutputParser
@@ -43,3 +45,27 @@ def test_combining_dict_result() -> None:
     combining_parser = CombiningOutputParser(parsers=parsers)
     result_dict = combining_parser.parse(DEF_README)
     assert DEF_EXPECTED_RESULT == result_dict
+
+
+def test_combining_output_parser_output_type() -> None:
+    """Test combining output parser output type is Dict[str, Any]."""
+    parsers = [
+        StructuredOutputParser(
+            response_schemas=[
+                ResponseSchema(
+                    name="answer", description="answer to the user's question"
+                ),
+                ResponseSchema(
+                    name="source",
+                    description="source used to answer the user's question",
+                ),
+            ]
+        ),
+        RegexParser(
+            regex=r"Confidence: (A|B|C), Explanation: (.*)",
+            output_keys=["confidence", "explanation"],
+            default_output_key="noConfidence",
+        ),
+    ]
+    combining_parser = CombiningOutputParser(parsers=parsers)
+    assert combining_parser.OutputType is Dict[str, Any]
diff --git a/libs/langchain/tests/unit_tests/output_parsers/test_enum_parser.py b/libs/langchain/tests/unit_tests/output_parsers/test_enum_parser.py
index 57fe3e0717b..b385b65b530 100644
--- a/libs/langchain/tests/unit_tests/output_parsers/test_enum_parser.py
+++ b/libs/langchain/tests/unit_tests/output_parsers/test_enum_parser.py
@@ -30,3 +30,8 @@ def test_enum_output_parser_parse() -> None:
         assert False, "Should have raised OutputParserException"
     except OutputParserException:
         pass
+
+
+def test_enum_output_parser_output_type() -> None:
+    """Test the output type of the enum output parser is the expected enum."""
+    assert EnumOutputParser(enum=Colors).OutputType is Colors
diff --git a/libs/langchain/tests/unit_tests/output_parsers/test_fix.py b/libs/langchain/tests/unit_tests/output_parsers/test_fix.py
new file mode 100644
index 00000000000..0f1eaf94130
--- /dev/null
+++ b/libs/langchain/tests/unit_tests/output_parsers/test_fix.py
@@ -0,0 +1,121 @@
+from typing import Any
+
+import pytest
+from langchain_core.exceptions import OutputParserException
+from langchain_core.runnables import RunnablePassthrough
+
+from langchain.output_parsers.boolean import BooleanOutputParser
+from langchain.output_parsers.datetime import DatetimeOutputParser
+from langchain.output_parsers.fix import BaseOutputParser, OutputFixingParser
+
+
+class SuccessfulParseAfterRetries(BaseOutputParser[str]):
+    parse_count: int = 0  # Number of times parse has been called
+    attemp_count_before_success: int  # Number of times to fail before succeeding  # noqa
+
+    def parse(self, *args: Any, **kwargs: Any) -> str:
+        self.parse_count += 1
+        if self.parse_count <= self.attemp_count_before_success:
+            raise OutputParserException("error")
+        return "parsed"
+
+
+class SuccessfulParseAfterRetriesWithGetFormatInstructions(SuccessfulParseAfterRetries):  # noqa
+    def get_format_instructions(self) -> str:
+        return "instructions"
+
+
+@pytest.mark.parametrize(
+    "base_parser",
+    [
+        SuccessfulParseAfterRetries(attemp_count_before_success=5),
+        SuccessfulParseAfterRetriesWithGetFormatInstructions(
+            attemp_count_before_success=5
+        ),  # noqa: E501
+    ],
+)
+def test_output_fixing_parser_parse(
+    base_parser: SuccessfulParseAfterRetries,
+) -> None:
+    # preparation
+    n: int = (
+        base_parser.attemp_count_before_success
+    )  # Success on the (n+1)-th attempt  # noqa
+    base_parser = SuccessfulParseAfterRetries(attemp_count_before_success=n)
+    parser = OutputFixingParser(
+        parser=base_parser,
+        max_retries=n,  # n times to retry, that is, (n+1) times call
+        retry_chain=RunnablePassthrough(),
+        legacy=False,
+    )
+    # test
+    assert parser.parse("completion") == "parsed"
+    assert base_parser.parse_count == n + 1
+    # TODO: test whether "instructions" is passed to the retry_chain
+
+
+@pytest.mark.parametrize(
+    "base_parser",
+    [
+        SuccessfulParseAfterRetries(attemp_count_before_success=5),
+        SuccessfulParseAfterRetriesWithGetFormatInstructions(
+            attemp_count_before_success=5
+        ),  # noqa: E501
+    ],
+)
+async def test_output_fixing_parser_aparse(
+    base_parser: SuccessfulParseAfterRetries,
+) -> None:
+    n: int = (
+        base_parser.attemp_count_before_success
+    )  # Success on the (n+1)-th attempt  # noqa
+    base_parser = SuccessfulParseAfterRetries(attemp_count_before_success=n)
+    parser = OutputFixingParser(
+        parser=base_parser,
+        max_retries=n,  # n times to retry, that is, (n+1) times call
+        retry_chain=RunnablePassthrough(),
+        legacy=False,
+    )
+    assert (await parser.aparse("completion")) == "parsed"
+    assert base_parser.parse_count == n + 1
+    # TODO: test whether "instructions" is passed to the retry_chain
+
+
+def test_output_fixing_parser_parse_fail() -> None:
+    n: int = 5  # Success on the (n+1)-th attempt
+    base_parser = SuccessfulParseAfterRetries(attemp_count_before_success=n)
+    parser = OutputFixingParser(
+        parser=base_parser,
+        max_retries=n - 1,  # n-1 times to retry, that is, n times call
+        retry_chain=RunnablePassthrough(),
+        legacy=False,
+    )
+    with pytest.raises(OutputParserException):
+        parser.parse("completion")
+    assert base_parser.parse_count == n
+
+
+async def test_output_fixing_parser_aparse_fail() -> None:
+    n: int = 5  # Success on the (n+1)-th attempt
+    base_parser = SuccessfulParseAfterRetries(attemp_count_before_success=n)
+    parser = OutputFixingParser(
+        parser=base_parser,
+        max_retries=n - 1,  # n-1 times to retry, that is, n times call
+        retry_chain=RunnablePassthrough(),
+        legacy=False,
+    )
+    with pytest.raises(OutputParserException):
+        await parser.aparse("completion")
+    assert base_parser.parse_count == n
+
+
+@pytest.mark.parametrize(
+    "base_parser",
+    [
+        BooleanOutputParser(),
+        DatetimeOutputParser(),
+    ],
+)
+def test_output_fixing_parser_output_type(base_parser: BaseOutputParser) -> None:  # noqa: E501
+    parser = OutputFixingParser(parser=base_parser, retry_chain=RunnablePassthrough())  # noqa: E501
+    assert parser.OutputType is base_parser.OutputType
diff --git a/libs/langchain/tests/unit_tests/output_parsers/test_pandas_dataframe_parser.py b/libs/langchain/tests/unit_tests/output_parsers/test_pandas_dataframe_parser.py
index 1a61ba2c498..6b20614dd85 100644
--- a/libs/langchain/tests/unit_tests/output_parsers/test_pandas_dataframe_parser.py
+++ b/libs/langchain/tests/unit_tests/output_parsers/test_pandas_dataframe_parser.py
@@ -1,4 +1,6 @@
 """Test PandasDataframeParser"""
+from typing import Any, Dict
+
 import pandas as pd
 from langchain_core.exceptions import OutputParserException
 
@@ -108,3 +110,8 @@ def test_pandas_output_parser_invalid_special_op() -> None:
         assert False, "Should have raised OutputParserException"
     except OutputParserException:
         assert True
+
+
+def test_pandas_output_parser_output_type() -> None:
+    """Test the output type of the pandas dataframe output parser is a pandas dataframe."""  # noqa: E501
+    assert parser.OutputType is Dict[str, Any]
diff --git a/libs/langchain/tests/unit_tests/output_parsers/test_pydantic_parser.py b/libs/langchain/tests/unit_tests/output_parsers/test_pydantic_parser.py
index ee08ce26168..f8c7d42c709 100644
--- a/libs/langchain/tests/unit_tests/output_parsers/test_pydantic_parser.py
+++ b/libs/langchain/tests/unit_tests/output_parsers/test_pydantic_parser.py
@@ -60,6 +60,7 @@ def test_pydantic_output_parser() -> None:
     result = pydantic_parser.parse(DEF_RESULT)
     print("parse_result:", result)  # noqa: T201
     assert DEF_EXPECTED_RESULT == result
+    assert pydantic_parser.OutputType is TestModel
 
 
 def test_pydantic_output_parser_fail() -> None:
diff --git a/libs/langchain/tests/unit_tests/output_parsers/test_regex.py b/libs/langchain/tests/unit_tests/output_parsers/test_regex.py
new file mode 100644
index 00000000000..ef434b4ba79
--- /dev/null
+++ b/libs/langchain/tests/unit_tests/output_parsers/test_regex.py
@@ -0,0 +1,38 @@
+from typing import Dict
+
+from langchain.output_parsers.regex import RegexParser
+
+# NOTE: The almost same constant variables in ./test_combining_parser.py
+DEF_EXPECTED_RESULT = {
+    "confidence": "A",
+    "explanation": "Paris is the capital of France according to Wikipedia.",
+}
+
+DEF_README = """```json
+{
+    "answer": "Paris",
+    "source": "https://en.wikipedia.org/wiki/France"
+}
+```
+
+//Confidence: A, Explanation: Paris is the capital of France according to Wikipedia."""
+
+
+def test_regex_parser_parse() -> None:
+    """Test regex parser parse."""
+    parser = RegexParser(
+        regex=r"Confidence: (A|B|C), Explanation: (.*)",
+        output_keys=["confidence", "explanation"],
+        default_output_key="noConfidence",
+    )
+    assert DEF_EXPECTED_RESULT == parser.parse(DEF_README)
+
+
+def test_regex_parser_output_type() -> None:
+    """Test regex parser output type is Dict[str, str]."""
+    parser = RegexParser(
+        regex=r"Confidence: (A|B|C), Explanation: (.*)",
+        output_keys=["confidence", "explanation"],
+        default_output_key="noConfidence",
+    )
+    assert parser.OutputType is Dict[str, str]
diff --git a/libs/langchain/tests/unit_tests/output_parsers/test_regex_dict.py b/libs/langchain/tests/unit_tests/output_parsers/test_regex_dict.py
index c07d0c85671..b9cc1f38a35 100644
--- a/libs/langchain/tests/unit_tests/output_parsers/test_regex_dict.py
+++ b/libs/langchain/tests/unit_tests/output_parsers/test_regex_dict.py
@@ -1,5 +1,7 @@
 """Test in memory docstore."""
+from typing import Dict
+
 from langchain.output_parsers.regex_dict import RegexDictParser
 
 DEF_EXPECTED_RESULT = {"action": "Search", "action_input": "How to use this class?"}
@@ -36,3 +38,11 @@ def test_regex_dict_result() -> None:
     result_dict = regex_dict_parser.parse(DEF_README)
     print("parse_result:", result_dict)  # noqa: T201
     assert DEF_EXPECTED_RESULT == result_dict
+
+
+def test_regex_dict_output_type() -> None:
+    """Test regex dict output type."""
+    regex_dict_parser = RegexDictParser(
+        output_key_to_format=DEF_OUTPUT_KEY_TO_FORMAT, no_update_value="N/A"
+    )
+    assert regex_dict_parser.OutputType is Dict[str, str]
diff --git a/libs/langchain/tests/unit_tests/output_parsers/test_retry.py b/libs/langchain/tests/unit_tests/output_parsers/test_retry.py
new file mode 100644
index 00000000000..161ba32a980
--- /dev/null
+++ b/libs/langchain/tests/unit_tests/output_parsers/test_retry.py
@@ -0,0 +1,196 @@
+from typing import Any
+
+import pytest
+from langchain_core.prompt_values import StringPromptValue
+from langchain_core.runnables import RunnablePassthrough
+
+from langchain.output_parsers.boolean import BooleanOutputParser
+from langchain.output_parsers.datetime import DatetimeOutputParser
+from langchain.output_parsers.retry import (
+    BaseOutputParser,
+    OutputParserException,
+    RetryOutputParser,
+    RetryWithErrorOutputParser,
+)
+
+
+class SuccessfulParseAfterRetries(BaseOutputParser[str]):
+    parse_count: int = 0  # Number of times parse has been called
+    attemp_count_before_success: int  # Number of times to fail before succeeding  # noqa
+    error_msg: str = "error"
+
+    def parse(self, *args: Any, **kwargs: Any) -> str:
+        self.parse_count += 1
+        if self.parse_count <= self.attemp_count_before_success:
+            raise OutputParserException(self.error_msg)
+        return "parsed"
+
+
+def test_retry_output_parser_parse_with_prompt() -> None:
+    n: int = 5  # Success on the (n+1)-th attempt
+    base_parser = SuccessfulParseAfterRetries(attemp_count_before_success=n)
+    parser = RetryOutputParser(
+        parser=base_parser,
+        retry_chain=RunnablePassthrough(),
+        max_retries=n,  # n times to retry, that is, (n+1) times call
+        legacy=False,
+    )
+    actual = parser.parse_with_prompt("completion", StringPromptValue(text="dummy"))  # noqa: E501
+    assert actual == "parsed"
+    assert base_parser.parse_count == n + 1
+
+
+def test_retry_output_parser_parse_with_prompt_fail() -> None:
+    n: int = 5  # Success on the (n+1)-th attempt
+    base_parser = SuccessfulParseAfterRetries(attemp_count_before_success=n)
+    parser = RetryOutputParser(
+        parser=base_parser,
+        retry_chain=RunnablePassthrough(),
+        max_retries=n - 1,  # n-1 times to retry, that is, n times call
+        legacy=False,
+    )
+    with pytest.raises(OutputParserException):
+        parser.parse_with_prompt("completion", StringPromptValue(text="dummy"))
+    assert base_parser.parse_count == n
+
+
+async def test_retry_output_parser_aparse_with_prompt() -> None:
+    n: int = 5  # Success on the (n+1)-th attempt
+    base_parser = SuccessfulParseAfterRetries(attemp_count_before_success=n)
+    parser = RetryOutputParser(
+        parser=base_parser,
+        retry_chain=RunnablePassthrough(),
+        max_retries=n,  # n times to retry, that is, (n+1) times call
+        legacy=False,
+    )
+    actual = await parser.aparse_with_prompt(
+        "completion", StringPromptValue(text="dummy")
+    )
+    assert actual == "parsed"
+    assert base_parser.parse_count == n + 1
+
+
+async def test_retry_output_parser_aparse_with_prompt_fail() -> None:
+    n: int = 5  # Success on the (n+1)-th attempt
+    base_parser = SuccessfulParseAfterRetries(attemp_count_before_success=n)
+    parser = RetryOutputParser(
+        parser=base_parser,
+        retry_chain=RunnablePassthrough(),
+        max_retries=n - 1,  # n-1 times to retry, that is, n times call
+        legacy=False,
+    )
+    with pytest.raises(OutputParserException):
+        await parser.aparse_with_prompt("completion", StringPromptValue(text="dummy"))  # noqa: E501
+    assert base_parser.parse_count == n
+
+
+@pytest.mark.parametrize(
+    "base_parser",
+    [
+        BooleanOutputParser(),
+        DatetimeOutputParser(),
+    ],
+)
+def test_retry_output_parser_output_type(base_parser: BaseOutputParser) -> None:
+    parser = RetryOutputParser(
+        parser=base_parser,
+        retry_chain=RunnablePassthrough(),
+        legacy=False,
+    )
+    assert parser.OutputType is base_parser.OutputType
+
+
+def test_retry_output_parser_parse_is_not_implemented() -> None:
+    parser = RetryOutputParser(
+        parser=BooleanOutputParser(),
+        retry_chain=RunnablePassthrough(),
+        legacy=False,
+    )
+    with pytest.raises(NotImplementedError):
+        parser.parse("completion")
+
+
+def test_retry_with_error_output_parser_parse_with_prompt() -> None:
+    n: int = 5  # Success on the (n+1)-th attempt
+    base_parser = SuccessfulParseAfterRetries(attemp_count_before_success=n)
+    parser = RetryWithErrorOutputParser(
+        parser=base_parser,
+        retry_chain=RunnablePassthrough(),
+        max_retries=n,  # n times to retry, that is, (n+1) times call
+        legacy=False,
+    )
+    actual = parser.parse_with_prompt("completion", StringPromptValue(text="dummy"))  # noqa: E501
+    assert actual == "parsed"
+    assert base_parser.parse_count == n + 1
+
+
+def test_retry_with_error_output_parser_parse_with_prompt_fail() -> None:
+    n: int = 5  # Success on the (n+1)-th attempt
+    base_parser = SuccessfulParseAfterRetries(attemp_count_before_success=n)
+    parser = RetryWithErrorOutputParser(
+        parser=base_parser,
+        retry_chain=RunnablePassthrough(),
+        max_retries=n - 1,  # n-1 times to retry, that is, n times call
+        legacy=False,
+    )
+    with pytest.raises(OutputParserException):
+        parser.parse_with_prompt("completion", StringPromptValue(text="dummy"))
+    assert base_parser.parse_count == n
+
+
+async def test_retry_with_error_output_parser_aparse_with_prompt() -> None:
+    n: int = 5  # Success on the (n+1)-th attempt
+    base_parser = SuccessfulParseAfterRetries(attemp_count_before_success=n)
+    parser = RetryWithErrorOutputParser(
+        parser=base_parser,
+        retry_chain=RunnablePassthrough(),
+        max_retries=n,  # n times to retry, that is, (n+1) times call
+        legacy=False,
+    )
+    actual = await parser.aparse_with_prompt(
+        "completion", StringPromptValue(text="dummy")
+    )
+    assert actual == "parsed"
+    assert base_parser.parse_count == n + 1
+
+
+async def test_retry_with_error_output_parser_aparse_with_prompt_fail() -> None:  # noqa: E501
+    n: int = 5  # Success on the (n+1)-th attempt
+    base_parser = SuccessfulParseAfterRetries(attemp_count_before_success=n)
+    parser = RetryWithErrorOutputParser(
+        parser=base_parser,
+        retry_chain=RunnablePassthrough(),
+        max_retries=n - 1,  # n-1 times to retry, that is, n times call
+        legacy=False,
+    )
+    with pytest.raises(OutputParserException):
+        await parser.aparse_with_prompt("completion", StringPromptValue(text="dummy"))  # noqa: E501
+    assert base_parser.parse_count == n
+
+
+@pytest.mark.parametrize(
+    "base_parser",
+    [
+        BooleanOutputParser(),
+        DatetimeOutputParser(),
+    ],
+)
+def test_retry_with_error_output_parser_output_type(
+    base_parser: BaseOutputParser,
+) -> None:
+    parser = RetryWithErrorOutputParser(
+        parser=base_parser,
+        retry_chain=RunnablePassthrough(),
+        legacy=False,
+    )
+    assert parser.OutputType is base_parser.OutputType
+
+
+def test_retry_with_error_output_parser_parse_is_not_implemented() -> None:
+    parser = RetryWithErrorOutputParser(
+        parser=BooleanOutputParser(),
+        retry_chain=RunnablePassthrough(),
+        legacy=False,
+    )
+    with pytest.raises(NotImplementedError):
+        parser.parse("completion")
diff --git a/libs/langchain/tests/unit_tests/output_parsers/test_structured_parser.py b/libs/langchain/tests/unit_tests/output_parsers/test_structured_parser.py
index 8fec872eb26..857b427a410 100644
--- a/libs/langchain/tests/unit_tests/output_parsers/test_structured_parser.py
+++ b/libs/langchain/tests/unit_tests/output_parsers/test_structured_parser.py
@@ -1,9 +1,12 @@
+from typing import Any, Dict
+
 from langchain_core.exceptions import OutputParserException
 
 from langchain.output_parsers import ResponseSchema, StructuredOutputParser
 
 
 def test_parse() -> None:
+    """Test parsing structured output."""
     response_schemas = [
         ResponseSchema(name="name", description="desc"),
         ResponseSchema(name="age", description="desc"),
@@ -24,3 +27,13 @@ def test_parse() -> None:
         pass  # Test passes if OutputParserException is raised
     else:
         assert False, f"Expected OutputParserException, but got {parser.parse(text)}"
+
+
+def test_output_type() -> None:
+    """Test the output type of the structured output parser is Dict[str, Any]."""
+    response_schemas = [
+        ResponseSchema(name="name", description="desc"),
+        ResponseSchema(name="age", description="desc"),
+    ]
+    parser = StructuredOutputParser.from_response_schemas(response_schemas)
+    assert parser.OutputType == Dict[str, Any]
diff --git a/libs/langchain/tests/unit_tests/output_parsers/test_yaml_parser.py b/libs/langchain/tests/unit_tests/output_parsers/test_yaml_parser.py
index b48a353d8d3..065ca4aa96c 100644
--- a/libs/langchain/tests/unit_tests/output_parsers/test_yaml_parser.py
+++ b/libs/langchain/tests/unit_tests/output_parsers/test_yaml_parser.py
@@ -93,3 +93,9 @@ def test_yaml_output_parser_fail() -> None:
         assert "Failed to parse TestModel from completion" in str(e)
     else:
         assert False, "Expected OutputParserException"
+
+
+def test_yaml_output_parser_output_type() -> None:
+    """Test YamlOutputParser OutputType."""
+    yaml_parser = YamlOutputParser(pydantic_object=TestModel)
+    assert yaml_parser.OutputType is TestModel
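For reference, a minimal usage sketch (not part of the patch) of the non-legacy path that the new unit tests exercise; `RunnablePassthrough` stands in for a real `prompt | llm` retry chain, exactly as in the tests above.

```python
# Illustrative sketch only -- mirrors the new unit tests, not part of the diff.
from langchain_core.runnables import RunnablePassthrough

from langchain.output_parsers.boolean import BooleanOutputParser
from langchain.output_parsers.fix import OutputFixingParser

# RunnablePassthrough stands in for a real `prompt | llm` chain here.
parser = OutputFixingParser(
    parser=BooleanOutputParser(),
    retry_chain=RunnablePassthrough(),
    legacy=False,  # route retries through .invoke() instead of the legacy LLMChain.run()
)

# OutputType is delegated to the wrapped parser, as the new property does.
assert parser.OutputType is bool
```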