langchain: Make OutputFixingParser.from_llm() create a usable retry chain (#24687)

Description: OutputFixingParser.from_llm() creates a retry chain that
returns a Generation instance, when it should actually just return a
string.
Issue: https://github.com/langchain-ai/langchain/issues/24600
Twitter handle: scribu

---------

Co-authored-by: isaac hershenson <ihershenson@hmc.edu>
This commit is contained in:
Cristi Burcă 2024-07-26 21:55:47 +01:00 committed by GitHub
parent b3a23ddf93
commit 174e7d2ab2
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 23 additions and 7 deletions

View File

@@ -3,10 +3,9 @@ from __future__ import annotations
from typing import Any, TypeVar, Union
from langchain_core.exceptions import OutputParserException
from langchain_core.language_models import BaseLanguageModel
from langchain_core.output_parsers import BaseOutputParser
from langchain_core.output_parsers import BaseOutputParser, StrOutputParser
from langchain_core.prompts import BasePromptTemplate
from langchain_core.runnables import RunnableSerializable
from langchain_core.runnables import Runnable, RunnableSerializable
from typing_extensions import TypedDict
from langchain.output_parsers.prompts import NAIVE_FIX_PROMPT
@@ -42,7 +41,7 @@ class OutputFixingParser(BaseOutputParser[T]):
@classmethod
def from_llm(
cls,
llm: BaseLanguageModel,
llm: Runnable,
parser: BaseOutputParser[T],
prompt: BasePromptTemplate = NAIVE_FIX_PROMPT,
max_retries: int = 1,
@@ -58,7 +57,7 @@ class OutputFixingParser(BaseOutputParser[T]):
Returns:
OutputFixingParser
"""
chain = prompt | llm
chain = prompt | llm | StrOutputParser()
return cls(parser=parser, retry_chain=chain, max_retries=max_retries)
def parse(self, completion: str) -> T:

View File

@@ -4,7 +4,7 @@ from typing import Any, TypeVar, Union
from langchain_core.exceptions import OutputParserException
from langchain_core.language_models import BaseLanguageModel
from langchain_core.output_parsers import BaseOutputParser
from langchain_core.output_parsers import BaseOutputParser, StrOutputParser
from langchain_core.prompt_values import PromptValue
from langchain_core.prompts import BasePromptTemplate, PromptTemplate
from langchain_core.runnables import RunnableSerializable
@@ -82,7 +82,7 @@ class RetryOutputParser(BaseOutputParser[T]):
Returns:
RetryOutputParser
"""
chain = prompt | llm
chain = prompt | llm | StrOutputParser()
return cls(parser=parser, retry_chain=chain, max_retries=max_retries)
def parse_with_prompt(self, completion: str, prompt_value: PromptValue) -> T:

View File

@@ -3,6 +3,7 @@ from typing import Any, Callable, Dict, Optional, TypeVar
import pytest
from langchain_core.exceptions import OutputParserException
from langchain_core.messages import AIMessage
from langchain_core.prompts.prompt import PromptTemplate
from langchain_core.runnables import Runnable, RunnableLambda, RunnablePassthrough
from pytest_mock import MockerFixture
@@ -63,6 +64,22 @@ def test_output_fixing_parser_parse(
# TODO: test whether "instructions" is passed to the retry_chain
def test_output_fixing_parser_from_llm() -> None:
    """Regression test: the retry chain built by ``from_llm`` must hand the
    wrapped parser a plain string, not a raw model message/Generation.

    A stub model always answers with a valid ISO datetime, so a first parse
    of garbage input should be fixed on retry and succeed.
    """

    def stub_model(prompt: str) -> AIMessage:
        # Always reply with a parseable timestamp, as a chat message.
        return AIMessage("2024-07-08T00:00:00.000000Z")

    retries = 1
    fixing_parser = OutputFixingParser.from_llm(
        llm=RunnableLambda(stub_model),
        parser=DatetimeOutputParser(),
        max_retries=retries,
    )

    # "not a date" fails the first parse; the retry chain's string output
    # must then parse cleanly to a truthy datetime.
    assert fixing_parser.parse("not a date")
@pytest.mark.parametrize(
"base_parser",
[