diff --git a/libs/core/langchain_core/retrievers.py b/libs/core/langchain_core/retrievers.py index 897efbef867..c42311a4428 100644 --- a/libs/core/langchain_core/retrievers.py +++ b/libs/core/langchain_core/retrievers.py @@ -190,15 +190,40 @@ class BaseRetriever(RunnableSerializable[RetrieverInput, RetrieverOutput], ABC): retriever.invoke("query") """ + from langchain_core.callbacks.manager import CallbackManager + config = ensure_config(config) - return self.get_relevant_documents( - input, - callbacks=config.get("callbacks"), - tags=config.get("tags"), - metadata=config.get("metadata"), - run_name=config.get("run_name"), - **kwargs, + callback_manager = CallbackManager.configure( + config.get("callbacks"), + None, + verbose=kwargs.get("verbose", False), + inheritable_tags=config.get("tags"), + local_tags=self.tags, + inheritable_metadata=config.get("metadata"), + local_metadata=self.metadata, ) + run_manager = callback_manager.on_retriever_start( + dumpd(self), + input, + name=config.get("run_name"), + run_id=kwargs.pop("run_id", None), + ) + try: + _kwargs = kwargs if self._expects_other_args else {} + if self._new_arg_supported: + result = self._get_relevant_documents( + input, run_manager=run_manager, **_kwargs + ) + else: + result = self._get_relevant_documents(input, **_kwargs) + except Exception as e: + run_manager.on_retriever_error(e) + raise e + else: + run_manager.on_retriever_end( + result, + ) + return result async def ainvoke( self, @@ -224,15 +249,40 @@ class BaseRetriever(RunnableSerializable[RetrieverInput, RetrieverOutput], ABC): await retriever.ainvoke("query") """ + from langchain_core.callbacks.manager import AsyncCallbackManager + config = ensure_config(config) - return await self.aget_relevant_documents( - input, - callbacks=config.get("callbacks"), - tags=config.get("tags"), - metadata=config.get("metadata"), - run_name=config.get("run_name"), - **kwargs, + callback_manager = AsyncCallbackManager.configure( + config.get("callbacks"), + None, + verbose=kwargs.get("verbose", False), + inheritable_tags=config.get("tags"), + local_tags=self.tags, + inheritable_metadata=config.get("metadata"), + local_metadata=self.metadata, ) + run_manager = await callback_manager.on_retriever_start( + dumpd(self), + input, + name=config.get("run_name"), + run_id=kwargs.pop("run_id", None), + ) + try: + _kwargs = kwargs if self._expects_other_args else {} + if self._new_arg_supported: + result = await self._aget_relevant_documents( + input, run_manager=run_manager, **_kwargs + ) + else: + result = await self._aget_relevant_documents(input, **_kwargs) + except Exception as e: + await run_manager.on_retriever_error(e) + raise e + else: + await run_manager.on_retriever_end( + result, + ) + return result @abstractmethod def _get_relevant_documents( @@ -293,39 +343,16 @@ class BaseRetriever(RunnableSerializable[RetrieverInput, RetrieverOutput], ABC): Returns: List of relevant documents """ - from langchain_core.callbacks.manager import CallbackManager - - callback_manager = CallbackManager.configure( - callbacks, - None, - verbose=kwargs.get("verbose", False), - inheritable_tags=tags, - local_tags=self.tags, - inheritable_metadata=metadata, - local_metadata=self.metadata, - ) - run_manager = callback_manager.on_retriever_start( - dumpd(self), - query, - name=run_name, - run_id=kwargs.pop("run_id", None), - ) - try: - _kwargs = kwargs if self._expects_other_args else {} - if self._new_arg_supported: - result = self._get_relevant_documents( - query, run_manager=run_manager, **_kwargs - ) - else: - result = self._get_relevant_documents(query, **_kwargs) - except Exception as e: - run_manager.on_retriever_error(e) - raise e - else: - run_manager.on_retriever_end( - result, - ) - return result + config: RunnableConfig = {} + if callbacks: + config["callbacks"] = callbacks + if tags: + config["tags"] = tags + if metadata: + config["metadata"] = metadata + if run_name: + config["run_name"] = run_name + return self.invoke(query, config, **kwargs) @deprecated(since="0.1.46", alternative="ainvoke", removal="0.3.0") async def aget_relevant_documents( @@ -357,36 +384,13 @@ class BaseRetriever(RunnableSerializable[RetrieverInput, RetrieverOutput], ABC): Returns: List of relevant documents """ - from langchain_core.callbacks.manager import AsyncCallbackManager - - callback_manager = AsyncCallbackManager.configure( - callbacks, - None, - verbose=kwargs.get("verbose", False), - inheritable_tags=tags, - local_tags=self.tags, - inheritable_metadata=metadata, - local_metadata=self.metadata, - ) - run_manager = await callback_manager.on_retriever_start( - dumpd(self), - query, - name=run_name, - run_id=kwargs.pop("run_id", None), - ) - try: - _kwargs = kwargs if self._expects_other_args else {} - if self._new_arg_supported: - result = await self._aget_relevant_documents( - query, run_manager=run_manager, **_kwargs - ) - else: - result = await self._aget_relevant_documents(query, **_kwargs) - except Exception as e: - await run_manager.on_retriever_error(e) - raise e - else: - await run_manager.on_retriever_end( - result, - ) - return result + config: RunnableConfig = {} + if callbacks: + config["callbacks"] = callbacks + if tags: + config["tags"] = tags + if metadata: + config["metadata"] = metadata + if run_name: + config["run_name"] = run_name + return await self.ainvoke(query, config, **kwargs)