diff --git a/libs/langchain/langchain/output_parsers/fix.py b/libs/langchain/langchain/output_parsers/fix.py index 3a52f442626..a22bfff582b 100644 --- a/libs/langchain/langchain/output_parsers/fix.py +++ b/libs/langchain/langchain/output_parsers/fix.py @@ -3,10 +3,9 @@ from __future__ import annotations from typing import Any, TypeVar, Union from langchain_core.exceptions import OutputParserException -from langchain_core.language_models import BaseLanguageModel -from langchain_core.output_parsers import BaseOutputParser +from langchain_core.output_parsers import BaseOutputParser, StrOutputParser from langchain_core.prompts import BasePromptTemplate -from langchain_core.runnables import RunnableSerializable +from langchain_core.runnables import Runnable, RunnableSerializable from typing_extensions import TypedDict from langchain.output_parsers.prompts import NAIVE_FIX_PROMPT @@ -42,7 +41,7 @@ class OutputFixingParser(BaseOutputParser[T]): @classmethod def from_llm( cls, - llm: BaseLanguageModel, + llm: Runnable, parser: BaseOutputParser[T], prompt: BasePromptTemplate = NAIVE_FIX_PROMPT, max_retries: int = 1, @@ -58,7 +57,7 @@ class OutputFixingParser(BaseOutputParser[T]): Returns: OutputFixingParser """ - chain = prompt | llm + chain = prompt | llm | StrOutputParser() return cls(parser=parser, retry_chain=chain, max_retries=max_retries) def parse(self, completion: str) -> T: diff --git a/libs/langchain/langchain/output_parsers/retry.py b/libs/langchain/langchain/output_parsers/retry.py index 7d5a383903a..a896ae99bfd 100644 --- a/libs/langchain/langchain/output_parsers/retry.py +++ b/libs/langchain/langchain/output_parsers/retry.py @@ -4,7 +4,7 @@ from typing import Any, TypeVar, Union from langchain_core.exceptions import OutputParserException from langchain_core.language_models import BaseLanguageModel -from langchain_core.output_parsers import BaseOutputParser +from langchain_core.output_parsers import BaseOutputParser, StrOutputParser from langchain_core.prompt_values import PromptValue from langchain_core.prompts import BasePromptTemplate, PromptTemplate from langchain_core.runnables import RunnableSerializable @@ -82,7 +82,7 @@ class RetryOutputParser(BaseOutputParser[T]): Returns: RetryOutputParser """ - chain = prompt | llm + chain = prompt | llm | StrOutputParser() return cls(parser=parser, retry_chain=chain, max_retries=max_retries) def parse_with_prompt(self, completion: str, prompt_value: PromptValue) -> T: diff --git a/libs/langchain/tests/unit_tests/output_parsers/test_fix.py b/libs/langchain/tests/unit_tests/output_parsers/test_fix.py index a98d823a45b..61d2d8a0c46 100644 --- a/libs/langchain/tests/unit_tests/output_parsers/test_fix.py +++ b/libs/langchain/tests/unit_tests/output_parsers/test_fix.py @@ -3,6 +3,7 @@ from typing import Any, Callable, Dict, Optional, TypeVar import pytest from langchain_core.exceptions import OutputParserException +from langchain_core.messages import AIMessage from langchain_core.prompts.prompt import PromptTemplate from langchain_core.runnables import Runnable, RunnableLambda, RunnablePassthrough from pytest_mock import MockerFixture @@ -63,6 +64,22 @@ def test_output_fixing_parser_parse( # TODO: test whether "instructions" is passed to the retry_chain +def test_output_fixing_parser_from_llm() -> None: + def fake_llm(prompt: str) -> AIMessage: + return AIMessage("2024-07-08T00:00:00.000000Z") + + llm = RunnableLambda(fake_llm) + + n = 1 + parser = OutputFixingParser.from_llm( + llm=llm, + parser=DatetimeOutputParser(), + max_retries=n, + ) + + assert parser.parse("not a date") + + @pytest.mark.parametrize( "base_parser", [