propagate kwargs

This commit is contained in:
Chester Curme 2025-03-25 11:15:32 -04:00
parent c5e42a4027
commit e00ddec3a6

View File

@ -112,7 +112,9 @@ class EnsembleRetriever(BaseRetriever):
**kwargs,
)
try:
result = self.rank_fusion(input, run_manager=run_manager, config=config)
result = self.rank_fusion(
input, run_manager=run_manager, config=config, **kwargs
)
except Exception as e:
run_manager.on_retriever_error(e)
raise e
@ -146,7 +148,7 @@ class EnsembleRetriever(BaseRetriever):
)
try:
result = await self.arank_fusion(
input, run_manager=run_manager, config=config
input, run_manager=run_manager, config=config, **kwargs
)
except Exception as e:
await run_manager.on_retriever_error(e)
@ -163,6 +165,7 @@ class EnsembleRetriever(BaseRetriever):
query: str,
*,
run_manager: CallbackManagerForRetrieverRun,
**kwargs: Any,
) -> List[Document]:
"""
Get the relevant documents for a given query.
@ -175,7 +178,7 @@ class EnsembleRetriever(BaseRetriever):
"""
# Get fused result of the retrievers.
fused_documents = self.rank_fusion(query, run_manager)
fused_documents = self.rank_fusion(query, run_manager, **kwargs)
return fused_documents
@ -184,6 +187,7 @@ class EnsembleRetriever(BaseRetriever):
query: str,
*,
run_manager: AsyncCallbackManagerForRetrieverRun,
**kwargs: Any,
) -> List[Document]:
"""
Asynchronously get the relevant documents for a given query.
@ -196,7 +200,7 @@ class EnsembleRetriever(BaseRetriever):
"""
# Get fused result of the retrievers.
fused_documents = await self.arank_fusion(query, run_manager)
fused_documents = await self.arank_fusion(query, run_manager, **kwargs)
return fused_documents
@ -206,6 +210,7 @@ class EnsembleRetriever(BaseRetriever):
run_manager: CallbackManagerForRetrieverRun,
*,
config: Optional[RunnableConfig] = None,
**kwargs: Any,
) -> List[Document]:
"""
Retrieve the results of the retrievers and use rank_fusion_func to get
@ -225,6 +230,7 @@ class EnsembleRetriever(BaseRetriever):
patch_config(
config, callbacks=run_manager.get_child(tag=f"retriever_{i + 1}")
),
**kwargs,
)
for i, retriever in enumerate(self.retrievers)
]
@ -247,6 +253,7 @@ class EnsembleRetriever(BaseRetriever):
run_manager: AsyncCallbackManagerForRetrieverRun,
*,
config: Optional[RunnableConfig] = None,
**kwargs: Any,
) -> List[Document]:
"""
Asynchronously retrieve the results of the retrievers
@ -268,6 +275,7 @@ class EnsembleRetriever(BaseRetriever):
config,
callbacks=run_manager.get_child(tag=f"retriever_{i + 1}"),
),
**kwargs,
)
for i, retriever in enumerate(self.retrievers)
]