propagate kwargs

This commit is contained in:
Chester Curme 2025-03-25 11:15:32 -04:00
parent c5e42a4027
commit e00ddec3a6

View File

@ -112,7 +112,9 @@ class EnsembleRetriever(BaseRetriever):
**kwargs, **kwargs,
) )
try: try:
result = self.rank_fusion(input, run_manager=run_manager, config=config) result = self.rank_fusion(
input, run_manager=run_manager, config=config, **kwargs
)
except Exception as e: except Exception as e:
run_manager.on_retriever_error(e) run_manager.on_retriever_error(e)
raise e raise e
@ -146,7 +148,7 @@ class EnsembleRetriever(BaseRetriever):
) )
try: try:
result = await self.arank_fusion( result = await self.arank_fusion(
input, run_manager=run_manager, config=config input, run_manager=run_manager, config=config, **kwargs
) )
except Exception as e: except Exception as e:
await run_manager.on_retriever_error(e) await run_manager.on_retriever_error(e)
@ -163,6 +165,7 @@ class EnsembleRetriever(BaseRetriever):
query: str, query: str,
*, *,
run_manager: CallbackManagerForRetrieverRun, run_manager: CallbackManagerForRetrieverRun,
**kwargs: Any,
) -> List[Document]: ) -> List[Document]:
""" """
Get the relevant documents for a given query. Get the relevant documents for a given query.
@ -175,7 +178,7 @@ class EnsembleRetriever(BaseRetriever):
""" """
# Get fused result of the retrievers. # Get fused result of the retrievers.
fused_documents = self.rank_fusion(query, run_manager) fused_documents = self.rank_fusion(query, run_manager, **kwargs)
return fused_documents return fused_documents
@ -184,6 +187,7 @@ class EnsembleRetriever(BaseRetriever):
query: str, query: str,
*, *,
run_manager: AsyncCallbackManagerForRetrieverRun, run_manager: AsyncCallbackManagerForRetrieverRun,
**kwargs: Any,
) -> List[Document]: ) -> List[Document]:
""" """
Asynchronously get the relevant documents for a given query. Asynchronously get the relevant documents for a given query.
@ -196,7 +200,7 @@ class EnsembleRetriever(BaseRetriever):
""" """
# Get fused result of the retrievers. # Get fused result of the retrievers.
fused_documents = await self.arank_fusion(query, run_manager) fused_documents = await self.arank_fusion(query, run_manager, **kwargs)
return fused_documents return fused_documents
@ -206,6 +210,7 @@ class EnsembleRetriever(BaseRetriever):
run_manager: CallbackManagerForRetrieverRun, run_manager: CallbackManagerForRetrieverRun,
*, *,
config: Optional[RunnableConfig] = None, config: Optional[RunnableConfig] = None,
**kwargs: Any,
) -> List[Document]: ) -> List[Document]:
""" """
Retrieve the results of the retrievers and use rank_fusion_func to get Retrieve the results of the retrievers and use rank_fusion_func to get
@ -225,6 +230,7 @@ class EnsembleRetriever(BaseRetriever):
patch_config( patch_config(
config, callbacks=run_manager.get_child(tag=f"retriever_{i + 1}") config, callbacks=run_manager.get_child(tag=f"retriever_{i + 1}")
), ),
**kwargs,
) )
for i, retriever in enumerate(self.retrievers) for i, retriever in enumerate(self.retrievers)
] ]
@ -247,6 +253,7 @@ class EnsembleRetriever(BaseRetriever):
run_manager: AsyncCallbackManagerForRetrieverRun, run_manager: AsyncCallbackManagerForRetrieverRun,
*, *,
config: Optional[RunnableConfig] = None, config: Optional[RunnableConfig] = None,
**kwargs: Any,
) -> List[Document]: ) -> List[Document]:
""" """
Asynchronously retrieve the results of the retrievers Asynchronously retrieve the results of the retrievers
@ -268,6 +275,7 @@ class EnsembleRetriever(BaseRetriever):
config, config,
callbacks=run_manager.get_child(tag=f"retriever_{i + 1}"), callbacks=run_manager.get_child(tag=f"retriever_{i + 1}"),
), ),
**kwargs,
) )
for i, retriever in enumerate(self.retrievers) for i, retriever in enumerate(self.retrievers)
] ]