fix(langchain): preserve supplied llm in FlareChain.from_llm (#32847)

Author: Gal Bloch
Date: 2025-09-09 16:41:23 +03:00
Committer: GitHub
Parent: 714f74a847
Commit: 428c2ee6c5
2 changed files with 82 additions and 6 deletions
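The behavioral change in one example: from_llm previously discarded a caller-supplied model and always built a fresh ChatOpenAI; it now keeps the supplied instance and validates it instead. A minimal usage sketch, assuming langchain-openai is installed and OPENAI_API_KEY is set; _NoopRetriever is a placeholder defined here only for illustration:

    from langchain_core.documents import Document
    from langchain_core.retrievers import BaseRetriever
    from langchain_openai import ChatOpenAI

    from langchain.chains.flare.base import FlareChain

    class _NoopRetriever(BaseRetriever):
        """Placeholder retriever; FlareChain requires one at construction."""

        def _get_relevant_documents(self, query: str) -> list[Document]:  # type: ignore[override]
            return []

    # The supplied model is now preserved (it must be a ChatOpenAI with logprobs=True).
    llm = ChatOpenAI(temperature=0, logprobs=True, max_completion_tokens=64)
    chain = FlareChain.from_llm(llm, retriever=_NoopRetriever())

    # Passing llm=None still builds the previous default model.
    default_chain = FlareChain.from_llm(None, max_generation_len=32, retriever=_NoopRetriever())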


@@ -249,7 +249,7 @@ class FlareChain(Chain):
     @classmethod
     def from_llm(
         cls,
-        llm: BaseLanguageModel,
+        llm: Optional[BaseLanguageModel],
         max_generation_len: int = 32,
         **kwargs: Any,
     ) -> FlareChain:
@@ -272,11 +272,36 @@ class FlareChain(Chain):
                 "pip install langchain-openai"
             )
             raise ImportError(msg) from e
-        llm = ChatOpenAI(
-            max_completion_tokens=max_generation_len,
-            logprobs=True,
-            temperature=0,
-        )
+        # Preserve supplied llm instead of always creating a new ChatOpenAI.
+        # Enforce ChatOpenAI requirement (token logprobs needed for FLARE).
+        if llm is None:
+            llm = ChatOpenAI(
+                max_completion_tokens=max_generation_len,
+                logprobs=True,
+                temperature=0,
+            )
+        else:
+            if not isinstance(llm, ChatOpenAI):
+                msg = (
+                    f"FlareChain.from_llm requires ChatOpenAI; got "
+                    f"{type(llm).__name__}."
+                )
+                raise TypeError(msg)
+            if not getattr(llm, "logprobs", False):  # attribute presence may vary
+                msg = (
+                    "Provided ChatOpenAI instance must be constructed with "
+                    "logprobs=True for FlareChain."
+                )
+                raise ValueError(msg)
+            current_max = getattr(llm, "max_completion_tokens", None)
+            if current_max is not None and current_max != max_generation_len:
+                logger.debug(
+                    "FlareChain.from_llm: supplied llm max_completion_tokens=%s "
+                    "differs from requested max_generation_len=%s; "
+                    "leaving model unchanged.",
+                    current_max,
+                    max_generation_len,
+                )
         response_chain = PROMPT | llm
         question_gen_chain = QUESTION_GENERATOR_PROMPT | llm | StrOutputParser()
         return cls(
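Design note: when the supplied model's max_completion_tokens disagrees with the requested max_generation_len, the new code only emits a debug log and leaves the caller's model untouched rather than mutating it. A short sketch of the two fail-fast paths (again assuming langchain-openai is installed and OPENAI_API_KEY is set):

    import pytest
    from langchain_openai import ChatOpenAI

    from langchain.chains.flare.base import FlareChain

    # A non-ChatOpenAI model is rejected with a TypeError.
    with pytest.raises(TypeError):
        FlareChain.from_llm(object())  # type: ignore[arg-type]

    # A ChatOpenAI built without logprobs=True is rejected with a ValueError,
    # before any other FlareChain construction happens.
    with pytest.raises(ValueError):
        FlareChain.from_llm(ChatOpenAI(temperature=0))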


@@ -0,0 +1,51 @@
+"""Tests for FlareChain.from_llm preserving supplied ChatOpenAI instance."""
+
+from typing import cast
+
+import pytest
+from langchain_core.documents import Document
+from langchain_core.retrievers import BaseRetriever
+from langchain_core.runnables import RunnableSequence
+
+from langchain.chains.flare.base import FlareChain
+
+
+class _EmptyRetriever(BaseRetriever):
+    """Minimal no-op retriever used only for constructing FlareChain in tests."""
+
+    def _get_relevant_documents(self, query: str) -> list[Document]:  # type: ignore[override]
+        del query  # mark used
+        return []
+
+    async def _aget_relevant_documents(self, query: str) -> list[Document]:  # type: ignore[override]
+        del query  # mark used
+        return []
+
+
+def test_from_llm_rejects_non_chatopenai() -> None:
+    class Dummy:
+        pass
+
+    with pytest.raises(TypeError):
+        FlareChain.from_llm(Dummy())  # type: ignore[arg-type]
+
+
+@pytest.mark.requires("langchain_openai")
+def test_from_llm_uses_supplied_chatopenai(monkeypatch: pytest.MonkeyPatch) -> None:
+    try:
+        from langchain_openai import ChatOpenAI
+    except ImportError:  # pragma: no cover
+        pytest.skip("langchain-openai not installed")
+
+    # Provide dummy API key to satisfy constructor env validation.
+    monkeypatch.setenv("OPENAI_API_KEY", "TEST")
+    supplied = ChatOpenAI(temperature=0.51, logprobs=True, max_completion_tokens=21)
+    chain = FlareChain.from_llm(
+        supplied,
+        max_generation_len=32,
+        retriever=_EmptyRetriever(),
+    )
+
+    llm_in_chain = cast("RunnableSequence", chain.question_generator_chain).steps[1]
+    assert llm_in_chain is supplied
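The identity assertion relies on how from_llm composes the chain: QUESTION_GENERATOR_PROMPT | llm | StrOutputParser() yields a RunnableSequence with the model at steps[1]. Since the response chain is built the same way (PROMPT | llm), an analogous check could be added inside the test above; a sketch, assuming FlareChain exposes the composed chain as its response_chain attribute:

    # Continuing from test_from_llm_uses_supplied_chatopenai above:
    # PROMPT | llm also places the supplied model at index 1 of the sequence.
    response_llm = cast("RunnableSequence", chain.response_chain).steps[1]
    assert response_llm is supplied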