diff --git a/libs/langchain/langchain/chains/flare/base.py b/libs/langchain/langchain/chains/flare/base.py
index 207b58b6645..d2621b3bfc3 100644
--- a/libs/langchain/langchain/chains/flare/base.py
+++ b/libs/langchain/langchain/chains/flare/base.py
@@ -249,7 +249,7 @@ class FlareChain(Chain):
     @classmethod
     def from_llm(
         cls,
-        llm: BaseLanguageModel,
+        llm: Optional[BaseLanguageModel],
         max_generation_len: int = 32,
         **kwargs: Any,
     ) -> FlareChain:
@@ -272,11 +272,36 @@ class FlareChain(Chain):
                 "pip install langchain-openai"
             )
             raise ImportError(msg) from e
-        llm = ChatOpenAI(
-            max_completion_tokens=max_generation_len,
-            logprobs=True,
-            temperature=0,
-        )
+        # Preserve the supplied llm instead of always creating a new ChatOpenAI.
+        # Enforce the ChatOpenAI requirement (token logprobs are needed for FLARE).
+        if llm is None:
+            llm = ChatOpenAI(
+                max_completion_tokens=max_generation_len,
+                logprobs=True,
+                temperature=0,
+            )
+        else:
+            if not isinstance(llm, ChatOpenAI):
+                msg = (
+                    "FlareChain.from_llm requires ChatOpenAI; got "
+                    f"{type(llm).__name__}."
+                )
+                raise TypeError(msg)
+            if not getattr(llm, "logprobs", False):  # attribute presence may vary
+                msg = (
+                    "Provided ChatOpenAI instance must be constructed with "
+                    "logprobs=True for FlareChain."
+                )
+                raise ValueError(msg)
+            current_max = getattr(llm, "max_completion_tokens", None)
+            if current_max is not None and current_max != max_generation_len:
+                logger.debug(
+                    "FlareChain.from_llm: supplied llm max_completion_tokens=%s "
+                    "differs from requested max_generation_len=%s; "
+                    "leaving model unchanged.",
+                    current_max,
+                    max_generation_len,
+                )
         response_chain = PROMPT | llm
         question_gen_chain = QUESTION_GENERATOR_PROMPT | llm | StrOutputParser()
         return cls(
diff --git a/libs/langchain/tests/unit_tests/chains/test_flare.py b/libs/langchain/tests/unit_tests/chains/test_flare.py
new file mode 100644
index 00000000000..5d7c679de2f
--- /dev/null
+++ b/libs/langchain/tests/unit_tests/chains/test_flare.py
@@ -0,0 +1,52 @@
+"""Tests for FlareChain.from_llm preserving a supplied ChatOpenAI instance."""
+
+from typing import cast
+
+import pytest
+from langchain_core.documents import Document
+from langchain_core.retrievers import BaseRetriever
+from langchain_core.runnables import RunnableSequence
+
+from langchain.chains.flare.base import FlareChain
+
+
+class _EmptyRetriever(BaseRetriever):
+    """Minimal no-op retriever used only to construct FlareChain in tests."""
+
+    def _get_relevant_documents(self, query: str) -> list[Document]:  # type: ignore[override]
+        del query  # mark used
+        return []
+
+    async def _aget_relevant_documents(self, query: str) -> list[Document]:  # type: ignore[override]
+        del query  # mark used
+        return []
+
+
+@pytest.mark.requires("langchain_openai")
+def test_from_llm_rejects_non_chatopenai() -> None:
+    class Dummy:
+        pass
+
+    with pytest.raises(TypeError):
+        FlareChain.from_llm(Dummy())  # type: ignore[arg-type]
+
+
+@pytest.mark.requires("langchain_openai")
+def test_from_llm_uses_supplied_chatopenai(monkeypatch: pytest.MonkeyPatch) -> None:
+    try:
+        from langchain_openai import ChatOpenAI
+    except ImportError:  # pragma: no cover
+        pytest.skip("langchain-openai not installed")
+
+    # Provide a dummy API key to satisfy constructor env validation.
+    monkeypatch.setenv("OPENAI_API_KEY", "TEST")
+
+    supplied = ChatOpenAI(temperature=0.51, logprobs=True, max_completion_tokens=21)
+    chain = FlareChain.from_llm(
+        supplied,
+        max_generation_len=32,
+        retriever=_EmptyRetriever(),
+    )
+
+    llm_in_chain = cast("RunnableSequence", chain.question_generator_chain).steps[1]
+    assert llm_in_chain is supplied
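
A minimal usage sketch of the revised contract (assumes langchain-openai is
installed and OPENAI_API_KEY is set; `my_retriever` is a hypothetical stand-in
for any BaseRetriever and is not part of the patch):

    from langchain_openai import ChatOpenAI
    from langchain.chains.flare.base import FlareChain

    # A supplied model is now preserved rather than replaced. It must be a
    # ChatOpenAI constructed with logprobs=True; anything else raises
    # TypeError (wrong class) or ValueError (logprobs missing).
    llm = ChatOpenAI(temperature=0, logprobs=True, max_completion_tokens=32)
    chain = FlareChain.from_llm(llm, max_generation_len=32, retriever=my_retriever)

    # Passing llm=None keeps the previous default: a fresh ChatOpenAI is
    # created with logprobs=True and max_completion_tokens=max_generation_len.
    default_chain = FlareChain.from_llm(None, retriever=my_retriever)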