mirror of
https://github.com/hwchase17/langchain.git
synced 2025-09-13 13:36:15 +00:00
fix(langchain): preserve supplied llm in FlareChain.from_llm
(#32847)
This commit is contained in:
@@ -249,7 +249,7 @@ class FlareChain(Chain):
|
||||
@classmethod
|
||||
def from_llm(
|
||||
cls,
|
||||
llm: BaseLanguageModel,
|
||||
llm: Optional[BaseLanguageModel],
|
||||
max_generation_len: int = 32,
|
||||
**kwargs: Any,
|
||||
) -> FlareChain:
|
||||
@@ -272,11 +272,36 @@ class FlareChain(Chain):
|
||||
"pip install langchain-openai"
|
||||
)
|
||||
raise ImportError(msg) from e
|
||||
llm = ChatOpenAI(
|
||||
max_completion_tokens=max_generation_len,
|
||||
logprobs=True,
|
||||
temperature=0,
|
||||
)
|
||||
# Preserve supplied llm instead of always creating a new ChatOpenAI.
|
||||
# Enforce ChatOpenAI requirement (token logprobs needed for FLARE).
|
||||
if llm is None:
|
||||
llm = ChatOpenAI(
|
||||
max_completion_tokens=max_generation_len,
|
||||
logprobs=True,
|
||||
temperature=0,
|
||||
)
|
||||
else:
|
||||
if not isinstance(llm, ChatOpenAI):
|
||||
msg = (
|
||||
f"FlareChain.from_llm requires ChatOpenAI; got "
|
||||
f"{type(llm).__name__}."
|
||||
)
|
||||
raise TypeError(msg)
|
||||
if not getattr(llm, "logprobs", False): # attribute presence may vary
|
||||
msg = (
|
||||
"Provided ChatOpenAI instance must be constructed with "
|
||||
"logprobs=True for FlareChain."
|
||||
)
|
||||
raise ValueError(msg)
|
||||
current_max = getattr(llm, "max_completion_tokens", None)
|
||||
if current_max is not None and current_max != max_generation_len:
|
||||
logger.debug(
|
||||
"FlareChain.from_llm: supplied llm max_completion_tokens=%s "
|
||||
"differs from requested max_generation_len=%s; "
|
||||
"leaving model unchanged.",
|
||||
current_max,
|
||||
max_generation_len,
|
||||
)
|
||||
response_chain = PROMPT | llm
|
||||
question_gen_chain = QUESTION_GENERATOR_PROMPT | llm | StrOutputParser()
|
||||
return cls(
|
||||
|
51
libs/langchain/tests/unit_tests/chains/test_flare.py
Normal file
51
libs/langchain/tests/unit_tests/chains/test_flare.py
Normal file
@@ -0,0 +1,51 @@
|
||||
"""Tests for FlareChain.from_llm preserving supplied ChatOpenAI instance."""
|
||||
|
||||
from typing import cast
|
||||
|
||||
import pytest
|
||||
from langchain_core.documents import Document
|
||||
from langchain_core.retrievers import BaseRetriever
|
||||
from langchain_core.runnables import RunnableSequence
|
||||
|
||||
from langchain.chains.flare.base import FlareChain
|
||||
|
||||
|
||||
class _EmptyRetriever(BaseRetriever):
|
||||
"""Minimal no-op retriever used only for constructing FlareChain in tests."""
|
||||
|
||||
def _get_relevant_documents(self, query: str) -> list[Document]: # type: ignore[override]
|
||||
del query # mark used
|
||||
return []
|
||||
|
||||
async def _aget_relevant_documents(self, query: str) -> list[Document]: # type: ignore[override]
|
||||
del query # mark used
|
||||
return []
|
||||
|
||||
|
||||
def test_from_llm_rejects_non_chatopenai() -> None:
|
||||
class Dummy:
|
||||
pass
|
||||
|
||||
with pytest.raises(TypeError):
|
||||
FlareChain.from_llm(Dummy()) # type: ignore[arg-type]
|
||||
|
||||
|
||||
@pytest.mark.requires("langchain_openai")
|
||||
def test_from_llm_uses_supplied_chatopenai(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
try:
|
||||
from langchain_openai import ChatOpenAI
|
||||
except ImportError: # pragma: no cover
|
||||
pytest.skip("langchain-openai not installed")
|
||||
|
||||
# Provide dummy API key to satisfy constructor env validation.
|
||||
monkeypatch.setenv("OPENAI_API_KEY", "TEST")
|
||||
|
||||
supplied = ChatOpenAI(temperature=0.51, logprobs=True, max_completion_tokens=21)
|
||||
chain = FlareChain.from_llm(
|
||||
supplied,
|
||||
max_generation_len=32,
|
||||
retriever=_EmptyRetriever(),
|
||||
)
|
||||
|
||||
llm_in_chain = cast("RunnableSequence", chain.question_generator_chain).steps[1]
|
||||
assert llm_in_chain is supplied
|
Reference in New Issue
Block a user