From d74e537dacc344c4a157951cfe906df6972caa9e Mon Sep 17 00:00:00 2001 From: Mason Daugherty Date: Wed, 10 Jun 2026 21:33:13 -0400 Subject: [PATCH] fix(langchain-classic): align `arank_fusion` string normalization with `rank_fusion` in `EnsembleRetriever` (#38051) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Closes #37736 --- `EnsembleRetriever` normalizes retriever outputs to `Document` objects in both `rank_fusion` (sync) and `arank_fusion` (async), but the two methods used different conditions: - `rank_fusion` wraps only bare strings: `isinstance(doc, str)` - `arank_fusion` wrapped anything that isn't a `Document`: `not isinstance(doc, Document)` If a retriever returns a non-string, non-`Document` value through the async path, `arank_fusion` would try to construct `Document(page_content=)` and Pydantic raises a `ValidationError`. The sync path handles the same input without crashing — the behavior is inconsistent. The fix is a one-line change in `arank_fusion` to use `isinstance(doc, str)`, matching the sync path exactly. Three tests were added to `test_ensemble.py`: - `test_rank_fusion_bare_strings` — sync path wraps bare strings into Documents - `test_arank_fusion_bare_strings` — async path wraps bare strings into Documents - `test_arank_fusion_matches_rank_fusion` — sync and async return identical results for normal Document input --- This continues the work from #37737 by @AliMuhammadAslam (credited as co-author), rebased onto `master` with the type-check lint failure resolved. Supersedes that PR. Co-authored-by: AliMuhammadAslam --- .../langchain_classic/retrievers/ensemble.py | 2 +- .../unit_tests/retrievers/test_ensemble.py | 64 ++++++++++++++++++- 2 files changed, 64 insertions(+), 2 deletions(-) diff --git a/libs/langchain/langchain_classic/retrievers/ensemble.py b/libs/langchain/langchain_classic/retrievers/ensemble.py index 1306708b481..dcbdf593df8 100644 --- a/libs/langchain/langchain_classic/retrievers/ensemble.py +++ b/libs/langchain/langchain_classic/retrievers/ensemble.py @@ -294,7 +294,7 @@ class EnsembleRetriever(BaseRetriever): # Enforce that retrieved docs are Documents for each list in retriever_docs for i in range(len(retriever_docs)): retriever_docs[i] = [ - Document(page_content=doc) if not isinstance(doc, Document) else doc + Document(page_content=cast("str", doc)) if isinstance(doc, str) else doc # type: ignore[unreachable] for doc in retriever_docs[i] ] diff --git a/libs/langchain/tests/unit_tests/retrievers/test_ensemble.py b/libs/langchain/tests/unit_tests/retrievers/test_ensemble.py index 1826b379e7e..e8846e678d5 100644 --- a/libs/langchain/tests/unit_tests/retrievers/test_ensemble.py +++ b/libs/langchain/tests/unit_tests/retrievers/test_ensemble.py @@ -1,4 +1,7 @@ -from langchain_core.callbacks.manager import CallbackManagerForRetrieverRun +from langchain_core.callbacks.manager import ( + AsyncCallbackManagerForRetrieverRun, + CallbackManagerForRetrieverRun, +) from langchain_core.documents import Document from langchain_core.retrievers import BaseRetriever from typing_extensions import override @@ -20,6 +23,30 @@ class MockRetriever(BaseRetriever): return self.docs +class BareStringRetriever(BaseRetriever): + """Retriever that returns bare strings instead of Documents.""" + + strings: list[str] + + @override + def _get_relevant_documents( + self, + query: str, + *, + run_manager: CallbackManagerForRetrieverRun | None = None, + ) -> list: + return list(self.strings) + + @override + async def _aget_relevant_documents( + self, + query: str, + *, + run_manager: AsyncCallbackManagerForRetrieverRun | None = None, + ) -> list: + return list(self.strings) + + def test_invoke() -> None: documents1 = [ Document(page_content="a", metadata={"id": 1}), @@ -92,3 +119,38 @@ def test_invoke() -> None: # Additionally, the document with page_content "b" will be ranked 1st. assert len(ranked_documents) == 3 assert ranked_documents[0].page_content == "b" + + +def test_rank_fusion_bare_strings() -> None: + """Bare strings returned by a retriever should be wrapped into Documents.""" + retriever = BareStringRetriever(strings=["foo", "bar"]) + ensemble = EnsembleRetriever(retrievers=[retriever], weights=[1.0]) + results = ensemble.invoke("_") + assert all(isinstance(doc, Document) for doc in results) + assert {doc.page_content for doc in results} == {"foo", "bar"} + + +async def test_arank_fusion_bare_strings() -> None: + """arank_fusion should wrap bare strings the same way rank_fusion does.""" + retriever = BareStringRetriever(strings=["foo", "bar"]) + ensemble = EnsembleRetriever(retrievers=[retriever], weights=[1.0]) + results = await ensemble.ainvoke("_") + assert all(isinstance(doc, Document) for doc in results) + assert {doc.page_content for doc in results} == {"foo", "bar"} + + +async def test_arank_fusion_matches_rank_fusion() -> None: + """Sync and async rank fusion should produce identical results.""" + docs = [ + Document(page_content="alpha", metadata={"id": 1}), + Document(page_content="beta", metadata={"id": 2}), + ] + retriever = MockRetriever(docs=docs) + ensemble = EnsembleRetriever(retrievers=[retriever], weights=[1.0]) + + sync_results = ensemble.invoke("_") + async_results = await ensemble.ainvoke("_") + + assert [d.page_content for d in sync_results] == [ + d.page_content for d in async_results + ]