mirror of
https://github.com/hwchase17/langchain.git
synced 2026-07-01 14:47:02 +00:00
fix(langchain-classic): align arank_fusion string normalization with rank_fusion in EnsembleRetriever (#38051)
Closes #37736 --- `EnsembleRetriever` normalizes retriever outputs to `Document` objects in both `rank_fusion` (sync) and `arank_fusion` (async), but the two methods used different conditions: - `rank_fusion` wraps only bare strings: `isinstance(doc, str)` - `arank_fusion` wrapped anything that isn't a `Document`: `not isinstance(doc, Document)` If a retriever returns a non-string, non-`Document` value through the async path, `arank_fusion` would try to construct `Document(page_content=<non-string>)` and Pydantic raises a `ValidationError`. The sync path handles the same input without crashing — the behavior is inconsistent. The fix is a one-line change in `arank_fusion` to use `isinstance(doc, str)`, matching the sync path exactly. Three tests were added to `test_ensemble.py`: - `test_rank_fusion_bare_strings` — sync path wraps bare strings into Documents - `test_arank_fusion_bare_strings` — async path wraps bare strings into Documents - `test_arank_fusion_matches_rank_fusion` — sync and async return identical results for normal Document input --- This continues the work from #37737 by @AliMuhammadAslam (credited as co-author), rebased onto `master` with the type-check lint failure resolved. Supersedes that PR. Co-authored-by: AliMuhammadAslam <aaalimohdaslam@gmail.com>
This commit is contained in:
@@ -294,7 +294,7 @@ class EnsembleRetriever(BaseRetriever):
|
||||
# Enforce that retrieved docs are Documents for each list in retriever_docs
|
||||
for i in range(len(retriever_docs)):
|
||||
retriever_docs[i] = [
|
||||
Document(page_content=doc) if not isinstance(doc, Document) else doc
|
||||
Document(page_content=cast("str", doc)) if isinstance(doc, str) else doc # type: ignore[unreachable]
|
||||
for doc in retriever_docs[i]
|
||||
]
|
||||
|
||||
|
||||
@@ -1,4 +1,7 @@
|
||||
from langchain_core.callbacks.manager import CallbackManagerForRetrieverRun
|
||||
from langchain_core.callbacks.manager import (
|
||||
AsyncCallbackManagerForRetrieverRun,
|
||||
CallbackManagerForRetrieverRun,
|
||||
)
|
||||
from langchain_core.documents import Document
|
||||
from langchain_core.retrievers import BaseRetriever
|
||||
from typing_extensions import override
|
||||
@@ -20,6 +23,30 @@ class MockRetriever(BaseRetriever):
|
||||
return self.docs
|
||||
|
||||
|
||||
class BareStringRetriever(BaseRetriever):
|
||||
"""Retriever that returns bare strings instead of Documents."""
|
||||
|
||||
strings: list[str]
|
||||
|
||||
@override
|
||||
def _get_relevant_documents(
|
||||
self,
|
||||
query: str,
|
||||
*,
|
||||
run_manager: CallbackManagerForRetrieverRun | None = None,
|
||||
) -> list:
|
||||
return list(self.strings)
|
||||
|
||||
@override
|
||||
async def _aget_relevant_documents(
|
||||
self,
|
||||
query: str,
|
||||
*,
|
||||
run_manager: AsyncCallbackManagerForRetrieverRun | None = None,
|
||||
) -> list:
|
||||
return list(self.strings)
|
||||
|
||||
|
||||
def test_invoke() -> None:
|
||||
documents1 = [
|
||||
Document(page_content="a", metadata={"id": 1}),
|
||||
@@ -92,3 +119,38 @@ def test_invoke() -> None:
|
||||
# Additionally, the document with page_content "b" will be ranked 1st.
|
||||
assert len(ranked_documents) == 3
|
||||
assert ranked_documents[0].page_content == "b"
|
||||
|
||||
|
||||
def test_rank_fusion_bare_strings() -> None:
|
||||
"""Bare strings returned by a retriever should be wrapped into Documents."""
|
||||
retriever = BareStringRetriever(strings=["foo", "bar"])
|
||||
ensemble = EnsembleRetriever(retrievers=[retriever], weights=[1.0])
|
||||
results = ensemble.invoke("_")
|
||||
assert all(isinstance(doc, Document) for doc in results)
|
||||
assert {doc.page_content for doc in results} == {"foo", "bar"}
|
||||
|
||||
|
||||
async def test_arank_fusion_bare_strings() -> None:
|
||||
"""arank_fusion should wrap bare strings the same way rank_fusion does."""
|
||||
retriever = BareStringRetriever(strings=["foo", "bar"])
|
||||
ensemble = EnsembleRetriever(retrievers=[retriever], weights=[1.0])
|
||||
results = await ensemble.ainvoke("_")
|
||||
assert all(isinstance(doc, Document) for doc in results)
|
||||
assert {doc.page_content for doc in results} == {"foo", "bar"}
|
||||
|
||||
|
||||
async def test_arank_fusion_matches_rank_fusion() -> None:
|
||||
"""Sync and async rank fusion should produce identical results."""
|
||||
docs = [
|
||||
Document(page_content="alpha", metadata={"id": 1}),
|
||||
Document(page_content="beta", metadata={"id": 2}),
|
||||
]
|
||||
retriever = MockRetriever(docs=docs)
|
||||
ensemble = EnsembleRetriever(retrievers=[retriever], weights=[1.0])
|
||||
|
||||
sync_results = ensemble.invoke("_")
|
||||
async_results = await ensemble.ainvoke("_")
|
||||
|
||||
assert [d.page_content for d in sync_results] == [
|
||||
d.page_content for d in async_results
|
||||
]
|
||||
|
||||
Reference in New Issue
Block a user