From ceb73ad06fc1d7004db2d3af89bfa200d6f6c1a7 Mon Sep 17 00:00:00 2001 From: Nuno Campos Date: Mon, 3 Jun 2024 07:34:53 -0700 Subject: [PATCH] core: In BaseRetriever make get_relevant_docs delegate to invoke (#22434) - This fixes all the tracing issues with people still using get_relevant_docs, and a change we need for 0.3 anyway Thank you for contributing to LangChain! - [ ] **PR title**: "package: description" - Where "package" is whichever of langchain, community, core, experimental, etc. is being modified. Use "docs: ..." for purely docs changes, "templates: ..." for template changes, "infra: ..." for CI changes. - Example: "community: add foobar LLM" - [ ] **PR message**: ***Delete this entire checklist*** and replace with - **Description:** a description of the change - **Issue:** the issue # it fixes, if applicable - **Dependencies:** any dependencies required for this change - **Twitter handle:** if your PR gets announced, and you'd like a mention, we'll gladly shout you out! - [ ] **Add tests and docs**: If you're adding a new integration, please include 1. a test for the integration, preferably unit tests that do not rely on network access, 2. an example notebook showing its use. It lives in `docs/docs/integrations` directory. - [ ] **Lint and test**: Run `make format`, `make lint` and `make test` from the root of the package(s) you've modified. See contribution guidelines for more: https://python.langchain.com/docs/contributing/ Additional guidelines: - Make sure optional dependencies are imported within a function. - Please do not add dependencies to pyproject.toml files (even optional ones) unless they are required for unit tests. - Most PRs should not touch more than one package. - Changes should be backwards compatible. - If you are adding something to community, do not re-import it in langchain. If no one reviews your PR within a few days, please @-mention one of baskaryan, efriis, eyurtsev, ccurme, vbarda, hwchase17. --- libs/core/langchain_core/retrievers.py | 164 +++++++++++++------------ 1 file changed, 84 insertions(+), 80 deletions(-) diff --git a/libs/core/langchain_core/retrievers.py b/libs/core/langchain_core/retrievers.py index 897efbef867..c42311a4428 100644 --- a/libs/core/langchain_core/retrievers.py +++ b/libs/core/langchain_core/retrievers.py @@ -190,15 +190,40 @@ class BaseRetriever(RunnableSerializable[RetrieverInput, RetrieverOutput], ABC): retriever.invoke("query") """ + from langchain_core.callbacks.manager import CallbackManager + config = ensure_config(config) - return self.get_relevant_documents( - input, - callbacks=config.get("callbacks"), - tags=config.get("tags"), - metadata=config.get("metadata"), - run_name=config.get("run_name"), - **kwargs, + callback_manager = CallbackManager.configure( + config.get("callbacks"), + None, + verbose=kwargs.get("verbose", False), + inheritable_tags=config.get("tags"), + local_tags=self.tags, + inheritable_metadata=config.get("metadata"), + local_metadata=self.metadata, ) + run_manager = callback_manager.on_retriever_start( + dumpd(self), + input, + name=config.get("run_name"), + run_id=kwargs.pop("run_id", None), + ) + try: + _kwargs = kwargs if self._expects_other_args else {} + if self._new_arg_supported: + result = self._get_relevant_documents( + input, run_manager=run_manager, **_kwargs + ) + else: + result = self._get_relevant_documents(input, **_kwargs) + except Exception as e: + run_manager.on_retriever_error(e) + raise e + else: + run_manager.on_retriever_end( + result, + ) + return result async def ainvoke( self, @@ -224,15 +249,40 @@ class BaseRetriever(RunnableSerializable[RetrieverInput, RetrieverOutput], ABC): await retriever.ainvoke("query") """ + from langchain_core.callbacks.manager import AsyncCallbackManager + config = ensure_config(config) - return await self.aget_relevant_documents( - input, - callbacks=config.get("callbacks"), - tags=config.get("tags"), - metadata=config.get("metadata"), - run_name=config.get("run_name"), - **kwargs, + callback_manager = AsyncCallbackManager.configure( + config.get("callbacks"), + None, + verbose=kwargs.get("verbose", False), + inheritable_tags=config.get("tags"), + local_tags=self.tags, + inheritable_metadata=config.get("metadata"), + local_metadata=self.metadata, ) + run_manager = await callback_manager.on_retriever_start( + dumpd(self), + input, + name=config.get("run_name"), + run_id=kwargs.pop("run_id", None), + ) + try: + _kwargs = kwargs if self._expects_other_args else {} + if self._new_arg_supported: + result = await self._aget_relevant_documents( + input, run_manager=run_manager, **_kwargs + ) + else: + result = await self._aget_relevant_documents(input, **_kwargs) + except Exception as e: + await run_manager.on_retriever_error(e) + raise e + else: + await run_manager.on_retriever_end( + result, + ) + return result @abstractmethod def _get_relevant_documents( @@ -293,39 +343,16 @@ class BaseRetriever(RunnableSerializable[RetrieverInput, RetrieverOutput], ABC): Returns: List of relevant documents """ - from langchain_core.callbacks.manager import CallbackManager - - callback_manager = CallbackManager.configure( - callbacks, - None, - verbose=kwargs.get("verbose", False), - inheritable_tags=tags, - local_tags=self.tags, - inheritable_metadata=metadata, - local_metadata=self.metadata, - ) - run_manager = callback_manager.on_retriever_start( - dumpd(self), - query, - name=run_name, - run_id=kwargs.pop("run_id", None), - ) - try: - _kwargs = kwargs if self._expects_other_args else {} - if self._new_arg_supported: - result = self._get_relevant_documents( - query, run_manager=run_manager, **_kwargs - ) - else: - result = self._get_relevant_documents(query, **_kwargs) - except Exception as e: - run_manager.on_retriever_error(e) - raise e - else: - run_manager.on_retriever_end( - result, - ) - return result + config: RunnableConfig = {} + if callbacks: + config["callbacks"] = callbacks + if tags: + config["tags"] = tags + if metadata: + config["metadata"] = metadata + if run_name: + config["run_name"] = run_name + return self.invoke(query, config, **kwargs) @deprecated(since="0.1.46", alternative="ainvoke", removal="0.3.0") async def aget_relevant_documents( @@ -357,36 +384,13 @@ class BaseRetriever(RunnableSerializable[RetrieverInput, RetrieverOutput], ABC): Returns: List of relevant documents """ - from langchain_core.callbacks.manager import AsyncCallbackManager - - callback_manager = AsyncCallbackManager.configure( - callbacks, - None, - verbose=kwargs.get("verbose", False), - inheritable_tags=tags, - local_tags=self.tags, - inheritable_metadata=metadata, - local_metadata=self.metadata, - ) - run_manager = await callback_manager.on_retriever_start( - dumpd(self), - query, - name=run_name, - run_id=kwargs.pop("run_id", None), - ) - try: - _kwargs = kwargs if self._expects_other_args else {} - if self._new_arg_supported: - result = await self._aget_relevant_documents( - query, run_manager=run_manager, **_kwargs - ) - else: - result = await self._aget_relevant_documents(query, **_kwargs) - except Exception as e: - await run_manager.on_retriever_error(e) - raise e - else: - await run_manager.on_retriever_end( - result, - ) - return result + config: RunnableConfig = {} + if callbacks: + config["callbacks"] = callbacks + if tags: + config["tags"] = tags + if metadata: + config["metadata"] = metadata + if run_name: + config["run_name"] = run_name + return await self.ainvoke(query, config, **kwargs)