Mirror of https://github.com/hwchase17/langchain.git, synced 2025-09-14 05:56:40 +00:00
fix(langchain): preserve supplied llm in FlareChain.from_llm (#32847)
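Previously, `from_llm` overwrote the `llm` parameter unconditionally, so any caller-configured model was silently discarded in favor of a freshly constructed `ChatOpenAI`. With this change, a supplied `ChatOpenAI` is kept and wired into both internal chains; a default model is built only when `llm` is None. A minimal usage sketch of the fixed behavior (assuming `langchain-openai` is installed and `OPENAI_API_KEY` is set; the no-op retriever is illustrative only):

```python
from langchain_core.documents import Document
from langchain_core.retrievers import BaseRetriever
from langchain_openai import ChatOpenAI
from langchain.chains.flare.base import FlareChain


class _NoopRetriever(BaseRetriever):
    """Illustrative stand-in; any BaseRetriever works here."""

    def _get_relevant_documents(self, query: str) -> list[Document]:
        return []


# logprobs=True is required by from_llm after this change.
custom_llm = ChatOpenAI(temperature=0, logprobs=True, max_completion_tokens=64)
chain = FlareChain.from_llm(custom_llm, retriever=_NoopRetriever())
# The chain now reuses custom_llm instead of replacing it with a new ChatOpenAI.
```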
libs/langchain/langchain/chains/flare/base.py

@@ -249,7 +249,7 @@ class FlareChain(Chain):
     @classmethod
     def from_llm(
         cls,
-        llm: BaseLanguageModel,
+        llm: Optional[BaseLanguageModel],
         max_generation_len: int = 32,
         **kwargs: Any,
     ) -> FlareChain:
@@ -272,11 +272,36 @@ class FlareChain(Chain):
                 "pip install langchain-openai"
             )
             raise ImportError(msg) from e
-        llm = ChatOpenAI(
-            max_completion_tokens=max_generation_len,
-            logprobs=True,
-            temperature=0,
-        )
+        # Preserve supplied llm instead of always creating a new ChatOpenAI.
+        # Enforce ChatOpenAI requirement (token logprobs needed for FLARE).
+        if llm is None:
+            llm = ChatOpenAI(
+                max_completion_tokens=max_generation_len,
+                logprobs=True,
+                temperature=0,
+            )
+        else:
+            if not isinstance(llm, ChatOpenAI):
+                msg = (
+                    f"FlareChain.from_llm requires ChatOpenAI; got "
+                    f"{type(llm).__name__}."
+                )
+                raise TypeError(msg)
+            if not getattr(llm, "logprobs", False):  # attribute presence may vary
+                msg = (
+                    "Provided ChatOpenAI instance must be constructed with "
+                    "logprobs=True for FlareChain."
+                )
+                raise ValueError(msg)
+            current_max = getattr(llm, "max_completion_tokens", None)
+            if current_max is not None and current_max != max_generation_len:
+                logger.debug(
+                    "FlareChain.from_llm: supplied llm max_completion_tokens=%s "
+                    "differs from requested max_generation_len=%s; "
+                    "leaving model unchanged.",
+                    current_max,
+                    max_generation_len,
+                )
         response_chain = PROMPT | llm
         question_gen_chain = QUESTION_GENERATOR_PROMPT | llm | StrOutputParser()
         return cls(
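The new validation is strict: anything other than a `ChatOpenAI` is rejected with a `TypeError`, and a `ChatOpenAI` built without `logprobs=True` is rejected with a `ValueError`, while a `max_completion_tokens` mismatch is only logged at debug level, deliberately leaving the caller's model untouched. A quick sketch of the two failure paths (again assuming `langchain-openai` is installed and `OPENAI_API_KEY` is set):

```python
from langchain_openai import ChatOpenAI
from langchain.chains.flare.base import FlareChain

try:
    FlareChain.from_llm(object())  # type: ignore[arg-type]  # not a ChatOpenAI
except TypeError as err:
    print(err)  # "FlareChain.from_llm requires ChatOpenAI; got object."

try:
    FlareChain.from_llm(ChatOpenAI(temperature=0))  # logprobs left unset
except ValueError as err:
    print(err)  # "...must be constructed with logprobs=True for FlareChain."
```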
libs/langchain/tests/unit_tests/chains/test_flare.py (new file, 51 lines)

@@ -0,0 +1,51 @@
+"""Tests for FlareChain.from_llm preserving supplied ChatOpenAI instance."""
+
+from typing import cast
+
+import pytest
+from langchain_core.documents import Document
+from langchain_core.retrievers import BaseRetriever
+from langchain_core.runnables import RunnableSequence
+
+from langchain.chains.flare.base import FlareChain
+
+
+class _EmptyRetriever(BaseRetriever):
+    """Minimal no-op retriever used only for constructing FlareChain in tests."""
+
+    def _get_relevant_documents(self, query: str) -> list[Document]:  # type: ignore[override]
+        del query  # mark used
+        return []
+
+    async def _aget_relevant_documents(self, query: str) -> list[Document]:  # type: ignore[override]
+        del query  # mark used
+        return []
+
+
+def test_from_llm_rejects_non_chatopenai() -> None:
+    class Dummy:
+        pass
+
+    with pytest.raises(TypeError):
+        FlareChain.from_llm(Dummy())  # type: ignore[arg-type]
+
+
+@pytest.mark.requires("langchain_openai")
+def test_from_llm_uses_supplied_chatopenai(monkeypatch: pytest.MonkeyPatch) -> None:
+    try:
+        from langchain_openai import ChatOpenAI
+    except ImportError:  # pragma: no cover
+        pytest.skip("langchain-openai not installed")
+
+    # Provide dummy API key to satisfy constructor env validation.
+    monkeypatch.setenv("OPENAI_API_KEY", "TEST")
+
+    supplied = ChatOpenAI(temperature=0.51, logprobs=True, max_completion_tokens=21)
+    chain = FlareChain.from_llm(
+        supplied,
+        max_generation_len=32,
+        retriever=_EmptyRetriever(),
+    )
+
+    llm_in_chain = cast("RunnableSequence", chain.question_generator_chain).steps[1]
+    assert llm_in_chain is supplied
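The identity assertion at the end works because `from_llm` composes the question generator as `QUESTION_GENERATOR_PROMPT | llm | StrOutputParser()`, so the model sits at index 1 of the resulting `RunnableSequence`. A self-contained illustration with a stand-in prompt and model (both hypothetical, not the chain's real components):

```python
from typing import cast

from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import PromptTemplate
from langchain_core.runnables import RunnableLambda, RunnableSequence

prompt = PromptTemplate.from_template("Q: {question}")  # stand-in prompt
fake_llm = RunnableLambda(lambda _: "answer")  # stand-in model

seq = cast(RunnableSequence, prompt | fake_llm | StrOutputParser())

# steps[0] -> prompt, steps[1] -> model, steps[2] -> parser,
# the same indexing the unit test relies on.
assert seq.steps[1] is fake_llm
```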