core[patch]: passthrough BaseRetriever.invoke(**kwargs) (#16551)

Fix for #16547
This commit is contained in:
Bagatur 2024-01-25 08:58:39 -08:00 committed by GitHub
parent 355ef2a4a6
commit e510cfaa23
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -115,7 +115,7 @@ class BaseRetriever(RunnableSerializable[RetrieverInput, RetrieverOutput], ABC):
) )
def invoke( def invoke(
self, input: str, config: Optional[RunnableConfig] = None self, input: str, config: Optional[RunnableConfig] = None, **kwargs: Any
) -> List[Document]: ) -> List[Document]:
config = ensure_config(config) config = ensure_config(config)
return self.get_relevant_documents( return self.get_relevant_documents(
@ -124,13 +124,14 @@ class BaseRetriever(RunnableSerializable[RetrieverInput, RetrieverOutput], ABC):
tags=config.get("tags"), tags=config.get("tags"),
metadata=config.get("metadata"), metadata=config.get("metadata"),
run_name=config.get("run_name"), run_name=config.get("run_name"),
**kwargs,
) )
async def ainvoke( async def ainvoke(
self, self,
input: str, input: str,
config: Optional[RunnableConfig] = None, config: Optional[RunnableConfig] = None,
**kwargs: Optional[Any], **kwargs: Any,
) -> List[Document]: ) -> List[Document]:
config = ensure_config(config) config = ensure_config(config)
return await self.aget_relevant_documents( return await self.aget_relevant_documents(
@ -139,6 +140,7 @@ class BaseRetriever(RunnableSerializable[RetrieverInput, RetrieverOutput], ABC):
tags=config.get("tags"), tags=config.get("tags"),
metadata=config.get("metadata"), metadata=config.get("metadata"),
run_name=config.get("run_name"), run_name=config.get("run_name"),
**kwargs,
) )
@abstractmethod @abstractmethod
@ -208,7 +210,6 @@ class BaseRetriever(RunnableSerializable[RetrieverInput, RetrieverOutput], ABC):
dumpd(self), dumpd(self),
query, query,
name=run_name, name=run_name,
**kwargs,
) )
try: try:
_kwargs = kwargs if self._expects_other_args else {} _kwargs = kwargs if self._expects_other_args else {}
@ -224,7 +225,6 @@ class BaseRetriever(RunnableSerializable[RetrieverInput, RetrieverOutput], ABC):
else: else:
run_manager.on_retriever_end( run_manager.on_retriever_end(
result, result,
**kwargs,
) )
return result return result
@ -266,7 +266,6 @@ class BaseRetriever(RunnableSerializable[RetrieverInput, RetrieverOutput], ABC):
dumpd(self), dumpd(self),
query, query,
name=run_name, name=run_name,
**kwargs,
) )
try: try:
_kwargs = kwargs if self._expects_other_args else {} _kwargs = kwargs if self._expects_other_args else {}
@ -282,6 +281,5 @@ class BaseRetriever(RunnableSerializable[RetrieverInput, RetrieverOutput], ABC):
else: else:
await run_manager.on_retriever_end( await run_manager.on_retriever_end(
result, result,
**kwargs,
) )
return result return result