langchain: Make OutputFixingParser.from_llm() create a useable retry chain (#24687)

Description: OutputFixingParser.from_llm() creates a retry chain that
returns a Generation instance, when it should actually just return a
string.
Issue: https://github.com/langchain-ai/langchain/issues/24600
Twitter handle: scribu

---------

Co-authored-by: isaac hershenson <ihershenson@hmc.edu>
This commit is contained in:
Cristi Burcă 2024-07-26 21:55:47 +01:00 committed by GitHub
parent b3a23ddf93
commit 174e7d2ab2
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 23 additions and 7 deletions

View File

@ -3,10 +3,9 @@ from __future__ import annotations
from typing import Any, TypeVar, Union from typing import Any, TypeVar, Union
from langchain_core.exceptions import OutputParserException from langchain_core.exceptions import OutputParserException
from langchain_core.language_models import BaseLanguageModel from langchain_core.output_parsers import BaseOutputParser, StrOutputParser
from langchain_core.output_parsers import BaseOutputParser
from langchain_core.prompts import BasePromptTemplate from langchain_core.prompts import BasePromptTemplate
from langchain_core.runnables import RunnableSerializable from langchain_core.runnables import Runnable, RunnableSerializable
from typing_extensions import TypedDict from typing_extensions import TypedDict
from langchain.output_parsers.prompts import NAIVE_FIX_PROMPT from langchain.output_parsers.prompts import NAIVE_FIX_PROMPT
@ -42,7 +41,7 @@ class OutputFixingParser(BaseOutputParser[T]):
@classmethod @classmethod
def from_llm( def from_llm(
cls, cls,
llm: BaseLanguageModel, llm: Runnable,
parser: BaseOutputParser[T], parser: BaseOutputParser[T],
prompt: BasePromptTemplate = NAIVE_FIX_PROMPT, prompt: BasePromptTemplate = NAIVE_FIX_PROMPT,
max_retries: int = 1, max_retries: int = 1,
@ -58,7 +57,7 @@ class OutputFixingParser(BaseOutputParser[T]):
Returns: Returns:
OutputFixingParser OutputFixingParser
""" """
chain = prompt | llm chain = prompt | llm | StrOutputParser()
return cls(parser=parser, retry_chain=chain, max_retries=max_retries) return cls(parser=parser, retry_chain=chain, max_retries=max_retries)
def parse(self, completion: str) -> T: def parse(self, completion: str) -> T:

View File

@ -4,7 +4,7 @@ from typing import Any, TypeVar, Union
from langchain_core.exceptions import OutputParserException from langchain_core.exceptions import OutputParserException
from langchain_core.language_models import BaseLanguageModel from langchain_core.language_models import BaseLanguageModel
from langchain_core.output_parsers import BaseOutputParser from langchain_core.output_parsers import BaseOutputParser, StrOutputParser
from langchain_core.prompt_values import PromptValue from langchain_core.prompt_values import PromptValue
from langchain_core.prompts import BasePromptTemplate, PromptTemplate from langchain_core.prompts import BasePromptTemplate, PromptTemplate
from langchain_core.runnables import RunnableSerializable from langchain_core.runnables import RunnableSerializable
@ -82,7 +82,7 @@ class RetryOutputParser(BaseOutputParser[T]):
Returns: Returns:
RetryOutputParser RetryOutputParser
""" """
chain = prompt | llm chain = prompt | llm | StrOutputParser()
return cls(parser=parser, retry_chain=chain, max_retries=max_retries) return cls(parser=parser, retry_chain=chain, max_retries=max_retries)
def parse_with_prompt(self, completion: str, prompt_value: PromptValue) -> T: def parse_with_prompt(self, completion: str, prompt_value: PromptValue) -> T:

View File

@ -3,6 +3,7 @@ from typing import Any, Callable, Dict, Optional, TypeVar
import pytest import pytest
from langchain_core.exceptions import OutputParserException from langchain_core.exceptions import OutputParserException
from langchain_core.messages import AIMessage
from langchain_core.prompts.prompt import PromptTemplate from langchain_core.prompts.prompt import PromptTemplate
from langchain_core.runnables import Runnable, RunnableLambda, RunnablePassthrough from langchain_core.runnables import Runnable, RunnableLambda, RunnablePassthrough
from pytest_mock import MockerFixture from pytest_mock import MockerFixture
@ -63,6 +64,22 @@ def test_output_fixing_parser_parse(
# TODO: test whether "instructions" is passed to the retry_chain # TODO: test whether "instructions" is passed to the retry_chain
def test_output_fixing_parser_from_llm() -> None:
def fake_llm(prompt: str) -> AIMessage:
return AIMessage("2024-07-08T00:00:00.000000Z")
llm = RunnableLambda(fake_llm)
n = 1
parser = OutputFixingParser.from_llm(
llm=llm,
parser=DatetimeOutputParser(),
max_retries=n,
)
assert parser.parse("not a date")
@pytest.mark.parametrize( @pytest.mark.parametrize(
"base_parser", "base_parser",
[ [