mirror of
https://github.com/hwchase17/langchain.git
synced 2025-07-04 04:07:54 +00:00
core[patch]: passthrough BaseRetriever.invoke(**kwargs) (#16551)
Fix for #16547
This commit is contained in:
parent
355ef2a4a6
commit
e510cfaa23
@ -115,7 +115,7 @@ class BaseRetriever(RunnableSerializable[RetrieverInput, RetrieverOutput], ABC):
|
|||||||
)
|
)
|
||||||
|
|
||||||
def invoke(
|
def invoke(
|
||||||
self, input: str, config: Optional[RunnableConfig] = None
|
self, input: str, config: Optional[RunnableConfig] = None, **kwargs: Any
|
||||||
) -> List[Document]:
|
) -> List[Document]:
|
||||||
config = ensure_config(config)
|
config = ensure_config(config)
|
||||||
return self.get_relevant_documents(
|
return self.get_relevant_documents(
|
||||||
@ -124,13 +124,14 @@ class BaseRetriever(RunnableSerializable[RetrieverInput, RetrieverOutput], ABC):
|
|||||||
tags=config.get("tags"),
|
tags=config.get("tags"),
|
||||||
metadata=config.get("metadata"),
|
metadata=config.get("metadata"),
|
||||||
run_name=config.get("run_name"),
|
run_name=config.get("run_name"),
|
||||||
|
**kwargs,
|
||||||
)
|
)
|
||||||
|
|
||||||
async def ainvoke(
|
async def ainvoke(
|
||||||
self,
|
self,
|
||||||
input: str,
|
input: str,
|
||||||
config: Optional[RunnableConfig] = None,
|
config: Optional[RunnableConfig] = None,
|
||||||
**kwargs: Optional[Any],
|
**kwargs: Any,
|
||||||
) -> List[Document]:
|
) -> List[Document]:
|
||||||
config = ensure_config(config)
|
config = ensure_config(config)
|
||||||
return await self.aget_relevant_documents(
|
return await self.aget_relevant_documents(
|
||||||
@ -139,6 +140,7 @@ class BaseRetriever(RunnableSerializable[RetrieverInput, RetrieverOutput], ABC):
|
|||||||
tags=config.get("tags"),
|
tags=config.get("tags"),
|
||||||
metadata=config.get("metadata"),
|
metadata=config.get("metadata"),
|
||||||
run_name=config.get("run_name"),
|
run_name=config.get("run_name"),
|
||||||
|
**kwargs,
|
||||||
)
|
)
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
@ -208,7 +210,6 @@ class BaseRetriever(RunnableSerializable[RetrieverInput, RetrieverOutput], ABC):
|
|||||||
dumpd(self),
|
dumpd(self),
|
||||||
query,
|
query,
|
||||||
name=run_name,
|
name=run_name,
|
||||||
**kwargs,
|
|
||||||
)
|
)
|
||||||
try:
|
try:
|
||||||
_kwargs = kwargs if self._expects_other_args else {}
|
_kwargs = kwargs if self._expects_other_args else {}
|
||||||
@ -224,7 +225,6 @@ class BaseRetriever(RunnableSerializable[RetrieverInput, RetrieverOutput], ABC):
|
|||||||
else:
|
else:
|
||||||
run_manager.on_retriever_end(
|
run_manager.on_retriever_end(
|
||||||
result,
|
result,
|
||||||
**kwargs,
|
|
||||||
)
|
)
|
||||||
return result
|
return result
|
||||||
|
|
||||||
@ -266,7 +266,6 @@ class BaseRetriever(RunnableSerializable[RetrieverInput, RetrieverOutput], ABC):
|
|||||||
dumpd(self),
|
dumpd(self),
|
||||||
query,
|
query,
|
||||||
name=run_name,
|
name=run_name,
|
||||||
**kwargs,
|
|
||||||
)
|
)
|
||||||
try:
|
try:
|
||||||
_kwargs = kwargs if self._expects_other_args else {}
|
_kwargs = kwargs if self._expects_other_args else {}
|
||||||
@ -282,6 +281,5 @@ class BaseRetriever(RunnableSerializable[RetrieverInput, RetrieverOutput], ABC):
|
|||||||
else:
|
else:
|
||||||
await run_manager.on_retriever_end(
|
await run_manager.on_retriever_end(
|
||||||
result,
|
result,
|
||||||
**kwargs,
|
|
||||||
)
|
)
|
||||||
return result
|
return result
|
||||||
|
Loading…
Reference in New Issue
Block a user