From e00ddec3a64298ac4f3cbf294fc028734b7cfdef Mon Sep 17 00:00:00 2001
From: Chester Curme <chester.curme@gmail.com>
Date: Tue, 25 Mar 2025 11:15:32 -0400
Subject: [PATCH] propagate kwargs

---
 libs/langchain/langchain/retrievers/ensemble.py | 16 ++++++++++++----
 1 file changed, 12 insertions(+), 4 deletions(-)

diff --git a/libs/langchain/langchain/retrievers/ensemble.py b/libs/langchain/langchain/retrievers/ensemble.py
index c99878d8080..6fb81054484 100644
--- a/libs/langchain/langchain/retrievers/ensemble.py
+++ b/libs/langchain/langchain/retrievers/ensemble.py
@@ -112,7 +112,9 @@ class EnsembleRetriever(BaseRetriever):
             **kwargs,
         )
         try:
-            result = self.rank_fusion(input, run_manager=run_manager, config=config)
+            result = self.rank_fusion(
+                input, run_manager=run_manager, config=config, **kwargs
+            )
         except Exception as e:
             run_manager.on_retriever_error(e)
             raise e
@@ -146,7 +148,7 @@ class EnsembleRetriever(BaseRetriever):
         )
         try:
             result = await self.arank_fusion(
-                input, run_manager=run_manager, config=config
+                input, run_manager=run_manager, config=config, **kwargs
             )
         except Exception as e:
             await run_manager.on_retriever_error(e)
@@ -163,6 +165,7 @@ class EnsembleRetriever(BaseRetriever):
         query: str,
         *,
         run_manager: CallbackManagerForRetrieverRun,
+        **kwargs: Any,
     ) -> List[Document]:
         """
         Get the relevant documents for a given query.
@@ -175,7 +178,7 @@ class EnsembleRetriever(BaseRetriever):
         """
 
         # Get fused result of the retrievers.
-        fused_documents = self.rank_fusion(query, run_manager)
+        fused_documents = self.rank_fusion(query, run_manager, **kwargs)
 
         return fused_documents
 
@@ -184,6 +187,7 @@ class EnsembleRetriever(BaseRetriever):
         query: str,
         *,
         run_manager: AsyncCallbackManagerForRetrieverRun,
+        **kwargs: Any,
     ) -> List[Document]:
         """
         Asynchronously get the relevant documents for a given query.
@@ -196,7 +200,7 @@ class EnsembleRetriever(BaseRetriever):
         """
 
         # Get fused result of the retrievers.
-        fused_documents = await self.arank_fusion(query, run_manager)
+        fused_documents = await self.arank_fusion(query, run_manager, **kwargs)
 
         return fused_documents
 
@@ -206,6 +210,7 @@ class EnsembleRetriever(BaseRetriever):
         run_manager: CallbackManagerForRetrieverRun,
         *,
         config: Optional[RunnableConfig] = None,
+        **kwargs: Any,
     ) -> List[Document]:
         """
         Retrieve the results of the retrievers and use rank_fusion_func to get
@@ -225,6 +230,7 @@ class EnsembleRetriever(BaseRetriever):
                 patch_config(
                     config, callbacks=run_manager.get_child(tag=f"retriever_{i + 1}")
                 ),
+                **kwargs,
             )
             for i, retriever in enumerate(self.retrievers)
         ]
@@ -247,6 +253,7 @@ class EnsembleRetriever(BaseRetriever):
         run_manager: AsyncCallbackManagerForRetrieverRun,
         *,
         config: Optional[RunnableConfig] = None,
+        **kwargs: Any,
     ) -> List[Document]:
         """
         Asynchronously retrieve the results of the retrievers
@@ -268,6 +275,7 @@ class EnsembleRetriever(BaseRetriever):
                         config,
                         callbacks=run_manager.get_child(tag=f"retriever_{i + 1}"),
                     ),
+                    **kwargs,
                 )
                 for i, retriever in enumerate(self.retrievers)
             ]