fix(langchain-classic): align arank_fusion string normalization with rank_fusion in EnsembleRetriever (#38051)

Closes #37736

---

`EnsembleRetriever` normalizes retriever outputs to `Document` objects
in both `rank_fusion` (sync) and `arank_fusion` (async), but the two
methods used different conditions:

- `rank_fusion` wraps only bare strings: `isinstance(doc, str)`
- `arank_fusion` wrapped anything that was not a `Document`:
  `not isinstance(doc, Document)`

If a retriever returned a non-string, non-`Document` value through the
async path, `arank_fusion` tried to construct
`Document(page_content=<non-string>)`, and Pydantic raised a
`ValidationError`. The sync path handles the same input without
crashing, so the two paths behaved inconsistently.

The fix is a one-line change in `arank_fusion` to use `isinstance(doc,
str)`, matching the sync path exactly.

Three tests were added to `test_ensemble.py`:

- `test_rank_fusion_bare_strings` — sync path wraps bare strings into
Documents
- `test_arank_fusion_bare_strings` — async path wraps bare strings into
Documents
- `test_arank_fusion_matches_rank_fusion` — sync and async return
identical results for normal Document input

---

This continues the work from #37737 by @AliMuhammadAslam (credited as
co-author), rebased onto `master` with the type-check lint failure
resolved. Supersedes that PR.

Co-authored-by: AliMuhammadAslam <aaalimohdaslam@gmail.com>
This commit is contained in:
Mason Daugherty
2026-06-10 21:33:13 -04:00
committed by GitHub
parent 6b9e22dbbc
commit d74e537dac
2 changed files with 64 additions and 2 deletions

View File

@@ -294,7 +294,7 @@ class EnsembleRetriever(BaseRetriever):
# Enforce that retrieved docs are Documents for each list in retriever_docs
for i in range(len(retriever_docs)):
retriever_docs[i] = [
Document(page_content=doc) if not isinstance(doc, Document) else doc
Document(page_content=cast("str", doc)) if isinstance(doc, str) else doc # type: ignore[unreachable]
for doc in retriever_docs[i]
]

View File

@@ -1,4 +1,7 @@
from langchain_core.callbacks.manager import CallbackManagerForRetrieverRun
from langchain_core.callbacks.manager import (
AsyncCallbackManagerForRetrieverRun,
CallbackManagerForRetrieverRun,
)
from langchain_core.documents import Document
from langchain_core.retrievers import BaseRetriever
from typing_extensions import override
@@ -20,6 +23,30 @@ class MockRetriever(BaseRetriever):
return self.docs
class BareStringRetriever(BaseRetriever):
    """Test retriever whose sync and async lookups yield plain `str` values.

    Both lookup methods ignore the query and hand back a shallow copy of
    `strings`, so the result contains raw strings rather than `Document`
    objects.
    """

    # Raw string payload returned, unwrapped, by both lookup methods.
    strings: list[str]

    @override
    def _get_relevant_documents(
        self,
        query: str,
        *,
        run_manager: CallbackManagerForRetrieverRun | None = None,
    ) -> list:
        # Return a fresh list so callers never share the configured field.
        return [*self.strings]

    @override
    async def _aget_relevant_documents(
        self,
        query: str,
        *,
        run_manager: AsyncCallbackManagerForRetrieverRun | None = None,
    ) -> list:
        # Async variant: same payload as the sync path.
        return [*self.strings]
def test_invoke() -> None:
documents1 = [
Document(page_content="a", metadata={"id": 1}),
@@ -92,3 +119,38 @@ def test_invoke() -> None:
# Additionally, the document with page_content "b" will be ranked 1st.
assert len(ranked_documents) == 3
assert ranked_documents[0].page_content == "b"
def test_rank_fusion_bare_strings() -> None:
    """The sync path must turn raw string results into `Document` objects."""
    ensemble = EnsembleRetriever(
        retrievers=[BareStringRetriever(strings=["foo", "bar"])],
        weights=[1.0],
    )

    results = ensemble.invoke("_")

    # Every returned item must be a Document, never a leftover bare string.
    for doc in results:
        assert isinstance(doc, Document)
    # Order is not asserted; only the set of wrapped contents matters.
    contents = {doc.page_content for doc in results}
    assert contents == {"foo", "bar"}
async def test_arank_fusion_bare_strings() -> None:
    """The async path must turn raw string results into `Document` objects."""
    ensemble = EnsembleRetriever(
        retrievers=[BareStringRetriever(strings=["foo", "bar"])],
        weights=[1.0],
    )

    results = await ensemble.ainvoke("_")

    # Every returned item must be a Document, never a leftover bare string.
    for doc in results:
        assert isinstance(doc, Document)
    # Order is not asserted; only the set of wrapped contents matters.
    contents = {doc.page_content for doc in results}
    assert contents == {"foo", "bar"}
async def test_arank_fusion_matches_rank_fusion() -> None:
    """`invoke` and `ainvoke` must return the same `page_content` sequence.

    Only the ordered `page_content` values are compared; metadata is not
    checked.
    """
    alpha = Document(page_content="alpha", metadata={"id": 1})
    beta = Document(page_content="beta", metadata={"id": 2})
    ensemble = EnsembleRetriever(
        retrievers=[MockRetriever(docs=[alpha, beta])],
        weights=[1.0],
    )

    sync_results = ensemble.invoke("_")
    async_results = await ensemble.ainvoke("_")

    sync_contents = [d.page_content for d in sync_results]
    async_contents = [d.page_content for d in async_results]
    assert sync_contents == async_contents