Pass config specs through ensemble retriever (#15917)

<!-- Thank you for contributing to LangChain!

Please title your PR "<package>: <description>", where <package> is
whichever of langchain, community, core, experimental, etc. is being
modified.

Replace this entire comment with:
  - **Description:** a description of the change, 
  - **Issue:** the issue # it fixes if applicable,
  - **Dependencies:** any dependencies required for this change,
- **Twitter handle:** we announce bigger features on Twitter. If your PR
gets announced, and you'd like a mention, we'll gladly shout you out!

Please make sure your PR is passing linting and testing before
submitting. Run `make format`, `make lint` and `make test` from the root
of the package you've modified to check this locally.

See contribution guidelines for more information on how to write/run
tests, lint, etc: https://python.langchain.com/docs/contributing/

If you're adding a new integration, please include:
1. a test for the integration, preferably unit tests that do not rely on
network access,
2. an example notebook showing its use. It lives in
`docs/docs/integrations` directory.

If no one reviews your PR within a few days, please @-mention one of
@baskaryan, @eyurtsev, @hwchase17.
 -->

---------

Co-authored-by: Harrison Chase <hw.chase.17@gmail.com>
This commit is contained in:
Nuno Campos 2024-01-11 16:22:17 -08:00 committed by GitHub
parent ebb6ad4f7a
commit 438beb6c94
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 197 additions and 23 deletions

View File

@ -24,7 +24,7 @@
},
{
"cell_type": "code",
"execution_count": 5,
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
@ -35,22 +35,31 @@
},
{
"cell_type": "code",
"execution_count": 6,
"execution_count": 14,
"metadata": {},
"outputs": [],
"source": [
"doc_list = [\n",
"doc_list_1 = [\n",
" \"I like apples\",\n",
" \"I like oranges\",\n",
" \"Apples and oranges are fruits\",\n",
"]\n",
"\n",
"# initialize the bm25 retriever and faiss retriever\n",
"bm25_retriever = BM25Retriever.from_texts(doc_list)\n",
"bm25_retriever = BM25Retriever.from_texts(\n",
" doc_list_1, metadatas=[{\"source\": 1}] * len(doc_list_1)\n",
")\n",
"bm25_retriever.k = 2\n",
"\n",
"doc_list_2 = [\n",
" \"You like apples\",\n",
" \"You like oranges\",\n",
"]\n",
"\n",
"embedding = OpenAIEmbeddings()\n",
"faiss_vectorstore = FAISS.from_texts(doc_list, embedding)\n",
"faiss_vectorstore = FAISS.from_texts(\n",
" doc_list_2, embedding, metadatas=[{\"source\": 2}] * len(doc_list_2)\n",
")\n",
"faiss_retriever = faiss_vectorstore.as_retriever(search_kwargs={\"k\": 2})\n",
"\n",
"# initialize the ensemble retriever\n",
@ -61,26 +70,92 @@
},
{
"cell_type": "code",
"execution_count": 7,
"execution_count": 15,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"[Document(page_content='I like apples'),\n",
" Document(page_content='Apples and oranges are fruits')]"
"[Document(page_content='You like apples', metadata={'source': 2}),\n",
" Document(page_content='I like apples', metadata={'source': 1}),\n",
" Document(page_content='You like oranges', metadata={'source': 2}),\n",
" Document(page_content='Apples and oranges are fruits', metadata={'source': 1})]"
]
},
"execution_count": 7,
"execution_count": 15,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"docs = ensemble_retriever.get_relevant_documents(\"apples\")\n",
"docs = ensemble_retriever.invoke(\"apples\")\n",
"docs"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Runtime Configuration\n",
"\n",
"We can also configure the retrievers at runtime. To do this, we need to mark the relevant fields as configurable."
]
},
{
"cell_type": "code",
"execution_count": 16,
"metadata": {},
"outputs": [],
"source": [
"from langchain_core.runnables import ConfigurableField"
]
},
{
"cell_type": "code",
"execution_count": 17,
"metadata": {},
"outputs": [],
"source": [
"faiss_retriever = faiss_vectorstore.as_retriever(\n",
" search_kwargs={\"k\": 2}\n",
").configurable_fields(\n",
" search_kwargs=ConfigurableField(\n",
" id=\"search_kwargs_faiss\",\n",
" name=\"Search Kwargs\",\n",
" description=\"The search kwargs to use\",\n",
" )\n",
")"
]
},
{
"cell_type": "code",
"execution_count": 18,
"metadata": {},
"outputs": [],
"source": [
"ensemble_retriever = EnsembleRetriever(\n",
" retrievers=[bm25_retriever, faiss_retriever], weights=[0.5, 0.5]\n",
")"
]
},
{
"cell_type": "code",
"execution_count": 19,
"metadata": {},
"outputs": [],
"source": [
"config = {\"configurable\": {\"search_kwargs_faiss\": {\"k\": 1}}}\n",
"docs = ensemble_retriever.invoke(\"apples\", config=config)\n",
"docs"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Notice that this returns only one source from the FAISS retriever, because we pass in the relevant configuration at runtime."
]
},
{
"cell_type": "code",
"execution_count": null,

View File

@ -2,11 +2,19 @@
Ensemble retriever that ensembles the results of
multiple retrievers by using weighted Reciprocal Rank Fusion.
"""
from typing import Any, Dict, List
import asyncio
from typing import Any, Dict, List, Optional
from langchain_core.documents import Document
from langchain_core.load.dump import dumpd
from langchain_core.pydantic_v1 import root_validator
from langchain_core.retrievers import BaseRetriever
from langchain_core.retrievers import BaseRetriever, RetrieverLike
from langchain_core.runnables import RunnableConfig
from langchain_core.runnables.config import ensure_config, patch_config
from langchain_core.runnables.utils import (
ConfigurableFieldSpec,
get_unique_config_specs,
)
from langchain.callbacks.manager import (
AsyncCallbackManagerForRetrieverRun,
@ -28,10 +36,17 @@ class EnsembleRetriever(BaseRetriever):
Default is 60.
"""
retrievers: List[BaseRetriever]
retrievers: List[RetrieverLike]
weights: List[float]
c: int = 60
@property
def config_specs(self) -> List[ConfigurableFieldSpec]:
    """Configurable field specs gathered from all wrapped retrievers.

    Every spec exposed by each entry of ``self.retrievers`` is collected
    and the combined stream is handed to ``get_unique_config_specs``.
    """
    # Flatten lazily: one spec stream across all child retrievers.
    all_specs = (
        child_spec
        for child in self.retrievers
        for child_spec in child.config_specs
    )
    return get_unique_config_specs(all_specs)
@root_validator(pre=True)
def set_weights(cls, values: Dict[str, Any]) -> Dict[str, Any]:
if not values.get("weights"):
@ -39,6 +54,74 @@ class EnsembleRetriever(BaseRetriever):
values["weights"] = [1 / n_retrievers] * n_retrievers
return values
def invoke(
    self, input: str, config: Optional[RunnableConfig] = None, **kwargs: Any
) -> List[Document]:
    """Run the ensemble synchronously for ``input``.

    A callback manager is configured from ``config`` (callbacks, tags,
    metadata) together with this retriever's own tags and metadata, a
    retriever run is started, and ``rank_fusion`` is called with the
    resolved config so that it can be forwarded to the child retrievers.

    Args:
        input: The query string.
        config: Optional runnable config; normalised via ``ensure_config``.
        **kwargs: Forwarded to the retriever start/end callbacks;
            ``verbose`` is also read from here (default ``False``).

    Returns:
        The documents returned by ``rank_fusion``.

    Raises:
        Exception: Whatever ``rank_fusion`` raises, re-raised after the
            error callback has been notified.
    """
    from langchain_core.callbacks.manager import CallbackManager

    resolved = ensure_config(config)
    manager = CallbackManager.configure(
        resolved.get("callbacks"),
        None,
        verbose=kwargs.get("verbose", False),
        inheritable_tags=resolved.get("tags", []),
        local_tags=self.tags,
        inheritable_metadata=resolved.get("metadata", {}),
        local_metadata=self.metadata,
    )
    run = manager.on_retriever_start(
        dumpd(self), input, name=resolved.get("run_name"), **kwargs
    )
    try:
        fused = self.rank_fusion(input, run_manager=run, config=resolved)
    except Exception as e:
        # Report the failure to the callbacks before propagating it.
        run.on_retriever_error(e)
        raise e
    # Only reached on success; errors raised here are not reported as
    # retriever errors.
    run.on_retriever_end(fused, **kwargs)
    return fused
async def ainvoke(
    self, input: str, config: Optional[RunnableConfig] = None, **kwargs: Any
) -> List[Document]:
    """Run the ensemble asynchronously for ``input``.

    Async counterpart of ``invoke``: an async callback manager is
    configured from ``config`` plus this retriever's own tags and
    metadata, a retriever run is started, and ``arank_fusion`` is awaited
    with the resolved config.

    Args:
        input: The query string.
        config: Optional runnable config; normalised via ``ensure_config``.
        **kwargs: Forwarded to the retriever start/end callbacks;
            ``verbose`` is also read from here (default ``False``).

    Returns:
        The documents returned by ``arank_fusion``.

    Raises:
        Exception: Whatever ``arank_fusion`` raises, re-raised after the
            error callback has been awaited.
    """
    from langchain_core.callbacks.manager import AsyncCallbackManager

    resolved = ensure_config(config)
    manager = AsyncCallbackManager.configure(
        resolved.get("callbacks"),
        None,
        verbose=kwargs.get("verbose", False),
        inheritable_tags=resolved.get("tags", []),
        local_tags=self.tags,
        inheritable_metadata=resolved.get("metadata", {}),
        local_metadata=self.metadata,
    )
    run = await manager.on_retriever_start(
        dumpd(self), input, name=resolved.get("run_name"), **kwargs
    )
    try:
        fused = await self.arank_fusion(input, run_manager=run, config=resolved)
    except Exception as e:
        # Report the failure to the callbacks before propagating it.
        await run.on_retriever_error(e)
        raise e
    # Only reached on success; errors raised here are not reported as
    # retriever errors.
    await run.on_retriever_end(fused, **kwargs)
    return fused
def _get_relevant_documents(
self,
query: str,
@ -82,7 +165,11 @@ class EnsembleRetriever(BaseRetriever):
return fused_documents
def rank_fusion(
self, query: str, run_manager: CallbackManagerForRetrieverRun
self,
query: str,
run_manager: CallbackManagerForRetrieverRun,
*,
config: Optional[RunnableConfig] = None,
) -> List[Document]:
"""
Retrieve the results of the retrievers and use rank_fusion_func to get
@ -97,8 +184,11 @@ class EnsembleRetriever(BaseRetriever):
# Get the results of all retrievers.
retriever_docs = [
retriever.get_relevant_documents(
query, callbacks=run_manager.get_child(tag=f"retriever_{i+1}")
retriever.invoke(
query,
patch_config(
config, callbacks=run_manager.get_child(tag=f"retriever_{i+1}")
),
)
for i, retriever in enumerate(self.retrievers)
]
@ -116,7 +206,11 @@ class EnsembleRetriever(BaseRetriever):
return fused_documents
async def arank_fusion(
self, query: str, run_manager: AsyncCallbackManagerForRetrieverRun
self,
query: str,
run_manager: AsyncCallbackManagerForRetrieverRun,
*,
config: Optional[RunnableConfig] = None,
) -> List[Document]:
"""
Asynchronously retrieve the results of the retrievers
@ -130,12 +224,17 @@ class EnsembleRetriever(BaseRetriever):
"""
# Get the results of all retrievers.
retriever_docs = [
await retriever.aget_relevant_documents(
query, callbacks=run_manager.get_child(tag=f"retriever_{i+1}")
)
for i, retriever in enumerate(self.retrievers)
]
retriever_docs = await asyncio.gather(
*[
retriever.ainvoke(
query,
patch_config(
config, callbacks=run_manager.get_child(tag=f"retriever_{i+1}")
),
)
for i, retriever in enumerate(self.retrievers)
]
)
# Enforce that retrieved docs are Documents for each list in retriever_docs
for i in range(len(retriever_docs)):