mirror of
https://github.com/hwchase17/langchain.git
synced 2025-04-27 19:46:55 +00:00
propagate kwargs
This commit is contained in:
parent
c5e42a4027
commit
e00ddec3a6
@ -112,7 +112,9 @@ class EnsembleRetriever(BaseRetriever):
|
||||
**kwargs,
|
||||
)
|
||||
try:
|
||||
result = self.rank_fusion(input, run_manager=run_manager, config=config)
|
||||
result = self.rank_fusion(
|
||||
input, run_manager=run_manager, config=config, **kwargs
|
||||
)
|
||||
except Exception as e:
|
||||
run_manager.on_retriever_error(e)
|
||||
raise e
|
||||
@ -146,7 +148,7 @@ class EnsembleRetriever(BaseRetriever):
|
||||
)
|
||||
try:
|
||||
result = await self.arank_fusion(
|
||||
input, run_manager=run_manager, config=config
|
||||
input, run_manager=run_manager, config=config, **kwargs
|
||||
)
|
||||
except Exception as e:
|
||||
await run_manager.on_retriever_error(e)
|
||||
@ -163,6 +165,7 @@ class EnsembleRetriever(BaseRetriever):
|
||||
query: str,
|
||||
*,
|
||||
run_manager: CallbackManagerForRetrieverRun,
|
||||
**kwargs: Any,
|
||||
) -> List[Document]:
|
||||
"""
|
||||
Get the relevant documents for a given query.
|
||||
@ -175,7 +178,7 @@ class EnsembleRetriever(BaseRetriever):
|
||||
"""
|
||||
|
||||
# Get fused result of the retrievers.
|
||||
fused_documents = self.rank_fusion(query, run_manager)
|
||||
fused_documents = self.rank_fusion(query, run_manager, **kwargs)
|
||||
|
||||
return fused_documents
|
||||
|
||||
@ -184,6 +187,7 @@ class EnsembleRetriever(BaseRetriever):
|
||||
query: str,
|
||||
*,
|
||||
run_manager: AsyncCallbackManagerForRetrieverRun,
|
||||
**kwargs: Any,
|
||||
) -> List[Document]:
|
||||
"""
|
||||
Asynchronously get the relevant documents for a given query.
|
||||
@ -196,7 +200,7 @@ class EnsembleRetriever(BaseRetriever):
|
||||
"""
|
||||
|
||||
# Get fused result of the retrievers.
|
||||
fused_documents = await self.arank_fusion(query, run_manager)
|
||||
fused_documents = await self.arank_fusion(query, run_manager, **kwargs)
|
||||
|
||||
return fused_documents
|
||||
|
||||
@ -206,6 +210,7 @@ class EnsembleRetriever(BaseRetriever):
|
||||
run_manager: CallbackManagerForRetrieverRun,
|
||||
*,
|
||||
config: Optional[RunnableConfig] = None,
|
||||
**kwargs: Any,
|
||||
) -> List[Document]:
|
||||
"""
|
||||
Retrieve the results of the retrievers and use rank_fusion_func to get
|
||||
@ -225,6 +230,7 @@ class EnsembleRetriever(BaseRetriever):
|
||||
patch_config(
|
||||
config, callbacks=run_manager.get_child(tag=f"retriever_{i + 1}")
|
||||
),
|
||||
**kwargs,
|
||||
)
|
||||
for i, retriever in enumerate(self.retrievers)
|
||||
]
|
||||
@ -247,6 +253,7 @@ class EnsembleRetriever(BaseRetriever):
|
||||
run_manager: AsyncCallbackManagerForRetrieverRun,
|
||||
*,
|
||||
config: Optional[RunnableConfig] = None,
|
||||
**kwargs: Any,
|
||||
) -> List[Document]:
|
||||
"""
|
||||
Asynchronously retrieve the results of the retrievers
|
||||
@ -268,6 +275,7 @@ class EnsembleRetriever(BaseRetriever):
|
||||
config,
|
||||
callbacks=run_manager.get_child(tag=f"retriever_{i + 1}"),
|
||||
),
|
||||
**kwargs,
|
||||
)
|
||||
for i, retriever in enumerate(self.retrievers)
|
||||
]
|
||||
|
Loading…
Reference in New Issue
Block a user