diff --git a/libs/langchain/langchain/retrievers/ensemble.py b/libs/langchain/langchain/retrievers/ensemble.py index c99878d8080..6fb81054484 100644 --- a/libs/langchain/langchain/retrievers/ensemble.py +++ b/libs/langchain/langchain/retrievers/ensemble.py @@ -112,7 +112,9 @@ class EnsembleRetriever(BaseRetriever): **kwargs, ) try: - result = self.rank_fusion(input, run_manager=run_manager, config=config) + result = self.rank_fusion( + input, run_manager=run_manager, config=config, **kwargs + ) except Exception as e: run_manager.on_retriever_error(e) raise e @@ -146,7 +148,7 @@ class EnsembleRetriever(BaseRetriever): ) try: result = await self.arank_fusion( - input, run_manager=run_manager, config=config + input, run_manager=run_manager, config=config, **kwargs ) except Exception as e: await run_manager.on_retriever_error(e) @@ -163,6 +165,7 @@ class EnsembleRetriever(BaseRetriever): query: str, *, run_manager: CallbackManagerForRetrieverRun, + **kwargs: Any, ) -> List[Document]: """ Get the relevant documents for a given query. @@ -175,7 +178,7 @@ class EnsembleRetriever(BaseRetriever): """ # Get fused result of the retrievers. - fused_documents = self.rank_fusion(query, run_manager) + fused_documents = self.rank_fusion(query, run_manager, **kwargs) return fused_documents @@ -184,6 +187,7 @@ class EnsembleRetriever(BaseRetriever): query: str, *, run_manager: AsyncCallbackManagerForRetrieverRun, + **kwargs: Any, ) -> List[Document]: """ Asynchronously get the relevant documents for a given query. @@ -196,7 +200,7 @@ class EnsembleRetriever(BaseRetriever): """ # Get fused result of the retrievers. - fused_documents = await self.arank_fusion(query, run_manager) + fused_documents = await self.arank_fusion(query, run_manager, **kwargs) return fused_documents @@ -206,6 +210,7 @@ class EnsembleRetriever(BaseRetriever): run_manager: CallbackManagerForRetrieverRun, *, config: Optional[RunnableConfig] = None, + **kwargs: Any, ) -> List[Document]: """ Retrieve the results of the retrievers and use rank_fusion_func to get @@ -225,6 +230,7 @@ class EnsembleRetriever(BaseRetriever): patch_config( config, callbacks=run_manager.get_child(tag=f"retriever_{i + 1}") ), + **kwargs, ) for i, retriever in enumerate(self.retrievers) ] @@ -247,6 +253,7 @@ class EnsembleRetriever(BaseRetriever): run_manager: AsyncCallbackManagerForRetrieverRun, *, config: Optional[RunnableConfig] = None, + **kwargs: Any, ) -> List[Document]: """ Asynchronously retrieve the results of the retrievers @@ -268,6 +275,7 @@ class EnsembleRetriever(BaseRetriever): config, callbacks=run_manager.get_child(tag=f"retriever_{i + 1}"), ), + **kwargs, ) for i, retriever in enumerate(self.retrievers) ]