mirror of
https://github.com/hwchase17/langchain.git
synced 2025-06-03 13:43:24 +00:00
langchain: Make OutputFixingParser.from_llm() create a useable retry chain (#24687)
Description: OutputFixingParser.from_llm() creates a retry chain that returns a Generation instance, when it should actually just return a string. Issue: https://github.com/langchain-ai/langchain/issues/24600 Twitter handle: scribu --------- Co-authored-by: isaac hershenson <ihershenson@hmc.edu>
This commit is contained in:
parent
b3a23ddf93
commit
174e7d2ab2
@ -3,10 +3,9 @@ from __future__ import annotations
|
|||||||
from typing import Any, TypeVar, Union
|
from typing import Any, TypeVar, Union
|
||||||
|
|
||||||
from langchain_core.exceptions import OutputParserException
|
from langchain_core.exceptions import OutputParserException
|
||||||
from langchain_core.language_models import BaseLanguageModel
|
from langchain_core.output_parsers import BaseOutputParser, StrOutputParser
|
||||||
from langchain_core.output_parsers import BaseOutputParser
|
|
||||||
from langchain_core.prompts import BasePromptTemplate
|
from langchain_core.prompts import BasePromptTemplate
|
||||||
from langchain_core.runnables import RunnableSerializable
|
from langchain_core.runnables import Runnable, RunnableSerializable
|
||||||
from typing_extensions import TypedDict
|
from typing_extensions import TypedDict
|
||||||
|
|
||||||
from langchain.output_parsers.prompts import NAIVE_FIX_PROMPT
|
from langchain.output_parsers.prompts import NAIVE_FIX_PROMPT
|
||||||
@ -42,7 +41,7 @@ class OutputFixingParser(BaseOutputParser[T]):
|
|||||||
@classmethod
|
@classmethod
|
||||||
def from_llm(
|
def from_llm(
|
||||||
cls,
|
cls,
|
||||||
llm: BaseLanguageModel,
|
llm: Runnable,
|
||||||
parser: BaseOutputParser[T],
|
parser: BaseOutputParser[T],
|
||||||
prompt: BasePromptTemplate = NAIVE_FIX_PROMPT,
|
prompt: BasePromptTemplate = NAIVE_FIX_PROMPT,
|
||||||
max_retries: int = 1,
|
max_retries: int = 1,
|
||||||
@ -58,7 +57,7 @@ class OutputFixingParser(BaseOutputParser[T]):
|
|||||||
Returns:
|
Returns:
|
||||||
OutputFixingParser
|
OutputFixingParser
|
||||||
"""
|
"""
|
||||||
chain = prompt | llm
|
chain = prompt | llm | StrOutputParser()
|
||||||
return cls(parser=parser, retry_chain=chain, max_retries=max_retries)
|
return cls(parser=parser, retry_chain=chain, max_retries=max_retries)
|
||||||
|
|
||||||
def parse(self, completion: str) -> T:
|
def parse(self, completion: str) -> T:
|
||||||
|
@ -4,7 +4,7 @@ from typing import Any, TypeVar, Union
|
|||||||
|
|
||||||
from langchain_core.exceptions import OutputParserException
|
from langchain_core.exceptions import OutputParserException
|
||||||
from langchain_core.language_models import BaseLanguageModel
|
from langchain_core.language_models import BaseLanguageModel
|
||||||
from langchain_core.output_parsers import BaseOutputParser
|
from langchain_core.output_parsers import BaseOutputParser, StrOutputParser
|
||||||
from langchain_core.prompt_values import PromptValue
|
from langchain_core.prompt_values import PromptValue
|
||||||
from langchain_core.prompts import BasePromptTemplate, PromptTemplate
|
from langchain_core.prompts import BasePromptTemplate, PromptTemplate
|
||||||
from langchain_core.runnables import RunnableSerializable
|
from langchain_core.runnables import RunnableSerializable
|
||||||
@ -82,7 +82,7 @@ class RetryOutputParser(BaseOutputParser[T]):
|
|||||||
Returns:
|
Returns:
|
||||||
RetryOutputParser
|
RetryOutputParser
|
||||||
"""
|
"""
|
||||||
chain = prompt | llm
|
chain = prompt | llm | StrOutputParser()
|
||||||
return cls(parser=parser, retry_chain=chain, max_retries=max_retries)
|
return cls(parser=parser, retry_chain=chain, max_retries=max_retries)
|
||||||
|
|
||||||
def parse_with_prompt(self, completion: str, prompt_value: PromptValue) -> T:
|
def parse_with_prompt(self, completion: str, prompt_value: PromptValue) -> T:
|
||||||
|
@ -3,6 +3,7 @@ from typing import Any, Callable, Dict, Optional, TypeVar
|
|||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
from langchain_core.exceptions import OutputParserException
|
from langchain_core.exceptions import OutputParserException
|
||||||
|
from langchain_core.messages import AIMessage
|
||||||
from langchain_core.prompts.prompt import PromptTemplate
|
from langchain_core.prompts.prompt import PromptTemplate
|
||||||
from langchain_core.runnables import Runnable, RunnableLambda, RunnablePassthrough
|
from langchain_core.runnables import Runnable, RunnableLambda, RunnablePassthrough
|
||||||
from pytest_mock import MockerFixture
|
from pytest_mock import MockerFixture
|
||||||
@ -63,6 +64,22 @@ def test_output_fixing_parser_parse(
|
|||||||
# TODO: test whether "instructions" is passed to the retry_chain
|
# TODO: test whether "instructions" is passed to the retry_chain
|
||||||
|
|
||||||
|
|
||||||
|
def test_output_fixing_parser_from_llm() -> None:
|
||||||
|
def fake_llm(prompt: str) -> AIMessage:
|
||||||
|
return AIMessage("2024-07-08T00:00:00.000000Z")
|
||||||
|
|
||||||
|
llm = RunnableLambda(fake_llm)
|
||||||
|
|
||||||
|
n = 1
|
||||||
|
parser = OutputFixingParser.from_llm(
|
||||||
|
llm=llm,
|
||||||
|
parser=DatetimeOutputParser(),
|
||||||
|
max_retries=n,
|
||||||
|
)
|
||||||
|
|
||||||
|
assert parser.parse("not a date")
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize(
|
@pytest.mark.parametrize(
|
||||||
"base_parser",
|
"base_parser",
|
||||||
[
|
[
|
||||||
|
Loading…
Reference in New Issue
Block a user