mirror of
https://github.com/hwchase17/langchain.git
synced 2025-06-01 20:49:17 +00:00
langchain: Make OutputFixingParser.from_llm() create a useable retry chain (#24687)
Description: OutputFixingParser.from_llm() creates a retry chain that returns a Generation instance, when it should actually just return a string. Issue: https://github.com/langchain-ai/langchain/issues/24600 Twitter handle: scribu --------- Co-authored-by: isaac hershenson <ihershenson@hmc.edu>
This commit is contained in:
parent
b3a23ddf93
commit
174e7d2ab2
@ -3,10 +3,9 @@ from __future__ import annotations
|
||||
from typing import Any, TypeVar, Union
|
||||
|
||||
from langchain_core.exceptions import OutputParserException
|
||||
from langchain_core.language_models import BaseLanguageModel
|
||||
from langchain_core.output_parsers import BaseOutputParser
|
||||
from langchain_core.output_parsers import BaseOutputParser, StrOutputParser
|
||||
from langchain_core.prompts import BasePromptTemplate
|
||||
from langchain_core.runnables import RunnableSerializable
|
||||
from langchain_core.runnables import Runnable, RunnableSerializable
|
||||
from typing_extensions import TypedDict
|
||||
|
||||
from langchain.output_parsers.prompts import NAIVE_FIX_PROMPT
|
||||
@ -42,7 +41,7 @@ class OutputFixingParser(BaseOutputParser[T]):
|
||||
@classmethod
|
||||
def from_llm(
|
||||
cls,
|
||||
llm: BaseLanguageModel,
|
||||
llm: Runnable,
|
||||
parser: BaseOutputParser[T],
|
||||
prompt: BasePromptTemplate = NAIVE_FIX_PROMPT,
|
||||
max_retries: int = 1,
|
||||
@ -58,7 +57,7 @@ class OutputFixingParser(BaseOutputParser[T]):
|
||||
Returns:
|
||||
OutputFixingParser
|
||||
"""
|
||||
chain = prompt | llm
|
||||
chain = prompt | llm | StrOutputParser()
|
||||
return cls(parser=parser, retry_chain=chain, max_retries=max_retries)
|
||||
|
||||
def parse(self, completion: str) -> T:
|
||||
|
@ -4,7 +4,7 @@ from typing import Any, TypeVar, Union
|
||||
|
||||
from langchain_core.exceptions import OutputParserException
|
||||
from langchain_core.language_models import BaseLanguageModel
|
||||
from langchain_core.output_parsers import BaseOutputParser
|
||||
from langchain_core.output_parsers import BaseOutputParser, StrOutputParser
|
||||
from langchain_core.prompt_values import PromptValue
|
||||
from langchain_core.prompts import BasePromptTemplate, PromptTemplate
|
||||
from langchain_core.runnables import RunnableSerializable
|
||||
@ -82,7 +82,7 @@ class RetryOutputParser(BaseOutputParser[T]):
|
||||
Returns:
|
||||
RetryOutputParser
|
||||
"""
|
||||
chain = prompt | llm
|
||||
chain = prompt | llm | StrOutputParser()
|
||||
return cls(parser=parser, retry_chain=chain, max_retries=max_retries)
|
||||
|
||||
def parse_with_prompt(self, completion: str, prompt_value: PromptValue) -> T:
|
||||
|
@ -3,6 +3,7 @@ from typing import Any, Callable, Dict, Optional, TypeVar
|
||||
|
||||
import pytest
|
||||
from langchain_core.exceptions import OutputParserException
|
||||
from langchain_core.messages import AIMessage
|
||||
from langchain_core.prompts.prompt import PromptTemplate
|
||||
from langchain_core.runnables import Runnable, RunnableLambda, RunnablePassthrough
|
||||
from pytest_mock import MockerFixture
|
||||
@ -63,6 +64,22 @@ def test_output_fixing_parser_parse(
|
||||
# TODO: test whether "instructions" is passed to the retry_chain
|
||||
|
||||
|
||||
def test_output_fixing_parser_from_llm() -> None:
|
||||
def fake_llm(prompt: str) -> AIMessage:
|
||||
return AIMessage("2024-07-08T00:00:00.000000Z")
|
||||
|
||||
llm = RunnableLambda(fake_llm)
|
||||
|
||||
n = 1
|
||||
parser = OutputFixingParser.from_llm(
|
||||
llm=llm,
|
||||
parser=DatetimeOutputParser(),
|
||||
max_retries=n,
|
||||
)
|
||||
|
||||
assert parser.parse("not a date")
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"base_parser",
|
||||
[
|
||||
|
Loading…
Reference in New Issue
Block a user