From 438beb6c94af387028d2109865febe7f02b39846 Mon Sep 17 00:00:00 2001 From: Nuno Campos Date: Thu, 11 Jan 2024 16:22:17 -0800 Subject: [PATCH] Pass config specs through ensemble retriever (#15917) --------- Co-authored-by: Harrison Chase --- .../data_connection/retrievers/ensemble.ipynb | 95 +++++++++++-- .../langchain/retrievers/ensemble.py | 125 ++++++++++++++++-- 2 files changed, 197 insertions(+), 23 deletions(-) diff --git a/docs/docs/modules/data_connection/retrievers/ensemble.ipynb b/docs/docs/modules/data_connection/retrievers/ensemble.ipynb index 5812ccc9ee5..961819b9561 100644 --- a/docs/docs/modules/data_connection/retrievers/ensemble.ipynb +++ b/docs/docs/modules/data_connection/retrievers/ensemble.ipynb @@ -24,7 +24,7 @@ }, { "cell_type": "code", - "execution_count": 5, + "execution_count": 1, "metadata": {}, "outputs": [], "source": [ @@ -35,22 +35,31 @@ }, { "cell_type": "code", - "execution_count": 6, + "execution_count": 14, "metadata": {}, "outputs": [], "source": [ - "doc_list = [\n", + "doc_list_1 = [\n", " \"I like apples\",\n", " \"I like oranges\",\n", " \"Apples and oranges are fruits\",\n", "]\n", "\n", "# initialize the bm25 retriever and faiss retriever\n", - "bm25_retriever = BM25Retriever.from_texts(doc_list)\n", + "bm25_retriever = BM25Retriever.from_texts(\n", + " doc_list_1, metadatas=[{\"source\": 1}] * len(doc_list_1)\n", + ")\n", "bm25_retriever.k = 2\n", "\n", + "doc_list_2 = [\n", + " \"You like apples\",\n", + " \"You like oranges\",\n", + "]\n", + "\n", "embedding = OpenAIEmbeddings()\n", - "faiss_vectorstore = FAISS.from_texts(doc_list, embedding)\n", + "faiss_vectorstore = FAISS.from_texts(\n", + " doc_list_2, embedding, metadatas=[{\"source\": 2}] * len(doc_list_2)\n", + ")\n", "faiss_retriever = faiss_vectorstore.as_retriever(search_kwargs={\"k\": 2})\n", "\n", "# initialize the ensemble retriever\n", @@ -61,26 +70,92 @@ }, { "cell_type": "code", - "execution_count": 7, + "execution_count": 15, "metadata": {}, "outputs": [ { "data": { "text/plain": [ - "[Document(page_content='I like apples'),\n", - " Document(page_content='Apples and oranges are fruits')]" + "[Document(page_content='You like apples', metadata={'source': 2}),\n", + " Document(page_content='I like apples', metadata={'source': 1}),\n", + " Document(page_content='You like oranges', metadata={'source': 2}),\n", + " Document(page_content='Apples and oranges are fruits', metadata={'source': 1})]" ] }, - "execution_count": 7, + "execution_count": 15, "metadata": {}, "output_type": "execute_result" } ], "source": [ - "docs = ensemble_retriever.get_relevant_documents(\"apples\")\n", + "docs = ensemble_retriever.invoke(\"apples\")\n", "docs" ] }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Runtime Configuration\n", + "\n", + "We can also configure the retrievers at runtime. In order to do this, we need to mark the fields as configurable" + ] + }, + { + "cell_type": "code", + "execution_count": 16, + "metadata": {}, + "outputs": [], + "source": [ + "from langchain_core.runnables import ConfigurableField" + ] + }, + { + "cell_type": "code", + "execution_count": 17, + "metadata": {}, + "outputs": [], + "source": [ + "faiss_retriever = faiss_vectorstore.as_retriever(\n", + " search_kwargs={\"k\": 2}\n", + ").configurable_fields(\n", + " search_kwargs=ConfigurableField(\n", + " id=\"search_kwargs_faiss\",\n", + " name=\"Search Kwargs\",\n", + " description=\"The search kwargs to use\",\n", + " )\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": 18, + "metadata": {}, + "outputs": [], + "source": [ + "ensemble_retriever = EnsembleRetriever(\n", + " retrievers=[bm25_retriever, faiss_retriever], weights=[0.5, 0.5]\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": 19, + "metadata": {}, + "outputs": [], + "source": [ + "config = {\"configurable\": {\"search_kwargs_faiss\": {\"k\": 1}}}\n", + "docs = ensemble_retriever.invoke(\"apples\", config=config)\n", + "docs" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Notice that this only returns one source from the FAISS retriever, because we pass in the relevant configuration at run time" + ] + }, { "cell_type": "code", "execution_count": null, diff --git a/libs/langchain/langchain/retrievers/ensemble.py b/libs/langchain/langchain/retrievers/ensemble.py index 7784775fe24..324ede8ae27 100644 --- a/libs/langchain/langchain/retrievers/ensemble.py +++ b/libs/langchain/langchain/retrievers/ensemble.py @@ -2,11 +2,19 @@ Ensemble retriever that ensemble the results of multiple retrievers by using weighted Reciprocal Rank Fusion """ -from typing import Any, Dict, List +import asyncio +from typing import Any, Dict, List, Optional from langchain_core.documents import Document +from langchain_core.load.dump import dumpd from langchain_core.pydantic_v1 import root_validator -from langchain_core.retrievers import BaseRetriever +from langchain_core.retrievers import BaseRetriever, RetrieverLike +from langchain_core.runnables import RunnableConfig +from langchain_core.runnables.config import ensure_config, patch_config +from langchain_core.runnables.utils import ( + ConfigurableFieldSpec, + get_unique_config_specs, +) from langchain.callbacks.manager import ( AsyncCallbackManagerForRetrieverRun, @@ -28,10 +36,17 @@ class EnsembleRetriever(BaseRetriever): Default is 60. """ - retrievers: List[BaseRetriever] + retrievers: List[RetrieverLike] weights: List[float] c: int = 60 + @property + def config_specs(self) -> List[ConfigurableFieldSpec]: + """List configurable fields for this runnable.""" + return get_unique_config_specs( + spec for retriever in self.retrievers for spec in retriever.config_specs + ) + @root_validator(pre=True) def set_weights(cls, values: Dict[str, Any]) -> Dict[str, Any]: if not values.get("weights"): @@ -39,6 +54,74 @@ class EnsembleRetriever(BaseRetriever): values["weights"] = [1 / n_retrievers] * n_retrievers return values + def invoke( + self, input: str, config: Optional[RunnableConfig] = None, **kwargs: Any + ) -> List[Document]: + from langchain_core.callbacks.manager import CallbackManager + + config = ensure_config(config) + callback_manager = CallbackManager.configure( + config.get("callbacks"), + None, + verbose=kwargs.get("verbose", False), + inheritable_tags=config.get("tags", []), + local_tags=self.tags, + inheritable_metadata=config.get("metadata", {}), + local_metadata=self.metadata, + ) + run_manager = callback_manager.on_retriever_start( + dumpd(self), + input, + name=config.get("run_name"), + **kwargs, + ) + try: + result = self.rank_fusion(input, run_manager=run_manager, config=config) + except Exception as e: + run_manager.on_retriever_error(e) + raise e + else: + run_manager.on_retriever_end( + result, + **kwargs, + ) + return result + + async def ainvoke( + self, input: str, config: Optional[RunnableConfig] = None, **kwargs: Any + ) -> List[Document]: + from langchain_core.callbacks.manager import AsyncCallbackManager + + config = ensure_config(config) + callback_manager = AsyncCallbackManager.configure( + config.get("callbacks"), + None, + verbose=kwargs.get("verbose", False), + inheritable_tags=config.get("tags", []), + local_tags=self.tags, + inheritable_metadata=config.get("metadata", {}), + local_metadata=self.metadata, + ) + run_manager = await callback_manager.on_retriever_start( + dumpd(self), + input, + name=config.get("run_name"), + **kwargs, + ) + try: + result = await self.arank_fusion( + input, run_manager=run_manager, config=config + ) + except Exception as e: + await run_manager.on_retriever_error(e) + raise e + else: + await run_manager.on_retriever_end( + result, + **kwargs, + ) + return result + def _get_relevant_documents( self, query: str, @@ -82,7 +165,11 @@ class EnsembleRetriever(BaseRetriever): return fused_documents def rank_fusion( - self, query: str, run_manager: CallbackManagerForRetrieverRun + self, + query: str, + run_manager: CallbackManagerForRetrieverRun, + *, + config: Optional[RunnableConfig] = None, ) -> List[Document]: """ Retrieve the results of the retrievers and use rank_fusion_func to get @@ -97,8 +184,11 @@ class EnsembleRetriever(BaseRetriever): # Get the results of all retrievers. retriever_docs = [ - retriever.get_relevant_documents( - query, callbacks=run_manager.get_child(tag=f"retriever_{i+1}") + retriever.invoke( + query, + patch_config( + config, callbacks=run_manager.get_child(tag=f"retriever_{i+1}") + ), ) for i, retriever in enumerate(self.retrievers) ] @@ -116,7 +206,11 @@ class EnsembleRetriever(BaseRetriever): return fused_documents async def arank_fusion( - self, query: str, run_manager: AsyncCallbackManagerForRetrieverRun + self, + query: str, + run_manager: AsyncCallbackManagerForRetrieverRun, + *, + config: Optional[RunnableConfig] = None, ) -> List[Document]: """ Asynchronously retrieve the results of the retrievers @@ -130,12 +224,17 @@ class EnsembleRetriever(BaseRetriever): """ # Get the results of all retrievers. - retriever_docs = [ - await retriever.aget_relevant_documents( - query, callbacks=run_manager.get_child(tag=f"retriever_{i+1}") - ) - for i, retriever in enumerate(self.retrievers) - ] + retriever_docs = await asyncio.gather( + *[ + retriever.ainvoke( + query, + patch_config( + config, callbacks=run_manager.get_child(tag=f"retriever_{i+1}") + ), + ) + for i, retriever in enumerate(self.retrievers) + ] + ) # Enforce that retrieved docs are Documents for each list in retriever_docs for i in range(len(retriever_docs)):