mirror of
https://github.com/hwchase17/langchain.git
synced 2025-07-19 11:08:55 +00:00
Pass config specs through ensemble retriever (#15917)
<!-- Thank you for contributing to LangChain! Please title your PR "<package>: <description>", where <package> is whichever of langchain, community, core, experimental, etc. is being modified. Replace this entire comment with: - **Description:** a description of the change, - **Issue:** the issue # it fixes if applicable, - **Dependencies:** any dependencies required for this change, - **Twitter handle:** we announce bigger features on Twitter. If your PR gets announced, and you'd like a mention, we'll gladly shout you out! Please make sure your PR is passing linting and testing before submitting. Run `make format`, `make lint` and `make test` from the root of the package you've modified to check this locally. See contribution guidelines for more information on how to write/run tests, lint, etc: https://python.langchain.com/docs/contributing/ If you're adding a new integration, please include: 1. a test for the integration, preferably unit tests that do not rely on network access, 2. an example notebook showing its use. It lives in `docs/docs/integrations` directory. If no one reviews your PR within a few days, please @-mention one of @baskaryan, @eyurtsev, @hwchase17. --> --------- Co-authored-by: Harrison Chase <hw.chase.17@gmail.com>
This commit is contained in:
parent
ebb6ad4f7a
commit
438beb6c94
@ -24,7 +24,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 5,
|
||||
"execution_count": 1,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
@ -35,22 +35,31 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 6,
|
||||
"execution_count": 14,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"doc_list = [\n",
|
||||
"doc_list_1 = [\n",
|
||||
" \"I like apples\",\n",
|
||||
" \"I like oranges\",\n",
|
||||
" \"Apples and oranges are fruits\",\n",
|
||||
"]\n",
|
||||
"\n",
|
||||
"# initialize the bm25 retriever and faiss retriever\n",
|
||||
"bm25_retriever = BM25Retriever.from_texts(doc_list)\n",
|
||||
"bm25_retriever = BM25Retriever.from_texts(\n",
|
||||
" doc_list_1, metadatas=[{\"source\": 1}] * len(doc_list_1)\n",
|
||||
")\n",
|
||||
"bm25_retriever.k = 2\n",
|
||||
"\n",
|
||||
"doc_list_2 = [\n",
|
||||
" \"You like apples\",\n",
|
||||
" \"You like oranges\",\n",
|
||||
"]\n",
|
||||
"\n",
|
||||
"embedding = OpenAIEmbeddings()\n",
|
||||
"faiss_vectorstore = FAISS.from_texts(doc_list, embedding)\n",
|
||||
"faiss_vectorstore = FAISS.from_texts(\n",
|
||||
" doc_list_2, embedding, metadatas=[{\"source\": 2}] * len(doc_list_2)\n",
|
||||
")\n",
|
||||
"faiss_retriever = faiss_vectorstore.as_retriever(search_kwargs={\"k\": 2})\n",
|
||||
"\n",
|
||||
"# initialize the ensemble retriever\n",
|
||||
@ -61,26 +70,92 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 7,
|
||||
"execution_count": 15,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"data": {
|
||||
"text/plain": [
|
||||
"[Document(page_content='I like apples'),\n",
|
||||
" Document(page_content='Apples and oranges are fruits')]"
|
||||
"[Document(page_content='You like apples', metadata={'source': 2}),\n",
|
||||
" Document(page_content='I like apples', metadata={'source': 1}),\n",
|
||||
" Document(page_content='You like oranges', metadata={'source': 2}),\n",
|
||||
" Document(page_content='Apples and oranges are fruits', metadata={'source': 1})]"
|
||||
]
|
||||
},
|
||||
"execution_count": 7,
|
||||
"execution_count": 15,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"docs = ensemble_retriever.get_relevant_documents(\"apples\")\n",
|
||||
"docs = ensemble_retriever.invoke(\"apples\")\n",
|
||||
"docs"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## Runtime Configuration\n",
|
||||
"\n",
|
||||
"We can also configure the retrievers at runtime. In order to do this, we need to mark the fields as configurable"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 16,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from langchain_core.runnables import ConfigurableField"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 17,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"faiss_retriever = faiss_vectorstore.as_retriever(\n",
|
||||
" search_kwargs={\"k\": 2}\n",
|
||||
").configurable_fields(\n",
|
||||
" search_kwargs=ConfigurableField(\n",
|
||||
" id=\"search_kwargs_faiss\",\n",
|
||||
" name=\"Search Kwargs\",\n",
|
||||
" description=\"The search kwargs to use\",\n",
|
||||
" )\n",
|
||||
")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 18,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"ensemble_retriever = EnsembleRetriever(\n",
|
||||
" retrievers=[bm25_retriever, faiss_retriever], weights=[0.5, 0.5]\n",
|
||||
")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 19,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"config = {\"configurable\": {\"search_kwargs_faiss\": {\"k\": 1}}}\n",
|
||||
"docs = ensemble_retriever.invoke(\"apples\", config=config)\n",
|
||||
"docs"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"Notice that this only returns one source from the FAISS retriever, because we pass in the relevant configuration at run time"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
|
@ -2,11 +2,19 @@
|
||||
Ensemble retriever that ensemble the results of
|
||||
multiple retrievers by using weighted Reciprocal Rank Fusion
|
||||
"""
|
||||
from typing import Any, Dict, List
|
||||
import asyncio
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from langchain_core.documents import Document
|
||||
from langchain_core.load.dump import dumpd
|
||||
from langchain_core.pydantic_v1 import root_validator
|
||||
from langchain_core.retrievers import BaseRetriever
|
||||
from langchain_core.retrievers import BaseRetriever, RetrieverLike
|
||||
from langchain_core.runnables import RunnableConfig
|
||||
from langchain_core.runnables.config import ensure_config, patch_config
|
||||
from langchain_core.runnables.utils import (
|
||||
ConfigurableFieldSpec,
|
||||
get_unique_config_specs,
|
||||
)
|
||||
|
||||
from langchain.callbacks.manager import (
|
||||
AsyncCallbackManagerForRetrieverRun,
|
||||
@ -28,10 +36,17 @@ class EnsembleRetriever(BaseRetriever):
|
||||
Default is 60.
|
||||
"""
|
||||
|
||||
retrievers: List[BaseRetriever]
|
||||
retrievers: List[RetrieverLike]
|
||||
weights: List[float]
|
||||
c: int = 60
|
||||
|
||||
@property
|
||||
def config_specs(self) -> List[ConfigurableFieldSpec]:
|
||||
"""List configurable fields for this runnable."""
|
||||
return get_unique_config_specs(
|
||||
spec for retriever in self.retrievers for spec in retriever.config_specs
|
||||
)
|
||||
|
||||
@root_validator(pre=True)
|
||||
def set_weights(cls, values: Dict[str, Any]) -> Dict[str, Any]:
|
||||
if not values.get("weights"):
|
||||
@ -39,6 +54,74 @@ class EnsembleRetriever(BaseRetriever):
|
||||
values["weights"] = [1 / n_retrievers] * n_retrievers
|
||||
return values
|
||||
|
||||
def invoke(
|
||||
self, input: str, config: Optional[RunnableConfig] = None, **kwargs: Any
|
||||
) -> List[Document]:
|
||||
from langchain_core.callbacks.manager import CallbackManager
|
||||
|
||||
config = ensure_config(config)
|
||||
callback_manager = CallbackManager.configure(
|
||||
config.get("callbacks"),
|
||||
None,
|
||||
verbose=kwargs.get("verbose", False),
|
||||
inheritable_tags=config.get("tags", []),
|
||||
local_tags=self.tags,
|
||||
inheritable_metadata=config.get("metadata", {}),
|
||||
local_metadata=self.metadata,
|
||||
)
|
||||
run_manager = callback_manager.on_retriever_start(
|
||||
dumpd(self),
|
||||
input,
|
||||
name=config.get("run_name"),
|
||||
**kwargs,
|
||||
)
|
||||
try:
|
||||
result = self.rank_fusion(input, run_manager=run_manager, config=config)
|
||||
except Exception as e:
|
||||
run_manager.on_retriever_error(e)
|
||||
raise e
|
||||
else:
|
||||
run_manager.on_retriever_end(
|
||||
result,
|
||||
**kwargs,
|
||||
)
|
||||
return result
|
||||
|
||||
async def ainvoke(
|
||||
self, input: str, config: Optional[RunnableConfig] = None, **kwargs: Any
|
||||
) -> List[Document]:
|
||||
from langchain_core.callbacks.manager import AsyncCallbackManager
|
||||
|
||||
config = ensure_config(config)
|
||||
callback_manager = AsyncCallbackManager.configure(
|
||||
config.get("callbacks"),
|
||||
None,
|
||||
verbose=kwargs.get("verbose", False),
|
||||
inheritable_tags=config.get("tags", []),
|
||||
local_tags=self.tags,
|
||||
inheritable_metadata=config.get("metadata", {}),
|
||||
local_metadata=self.metadata,
|
||||
)
|
||||
run_manager = await callback_manager.on_retriever_start(
|
||||
dumpd(self),
|
||||
input,
|
||||
name=config.get("run_name"),
|
||||
**kwargs,
|
||||
)
|
||||
try:
|
||||
result = await self.arank_fusion(
|
||||
input, run_manager=run_manager, config=config
|
||||
)
|
||||
except Exception as e:
|
||||
await run_manager.on_retriever_error(e)
|
||||
raise e
|
||||
else:
|
||||
await run_manager.on_retriever_end(
|
||||
result,
|
||||
**kwargs,
|
||||
)
|
||||
return result
|
||||
|
||||
def _get_relevant_documents(
|
||||
self,
|
||||
query: str,
|
||||
@ -82,7 +165,11 @@ class EnsembleRetriever(BaseRetriever):
|
||||
return fused_documents
|
||||
|
||||
def rank_fusion(
|
||||
self, query: str, run_manager: CallbackManagerForRetrieverRun
|
||||
self,
|
||||
query: str,
|
||||
run_manager: CallbackManagerForRetrieverRun,
|
||||
*,
|
||||
config: Optional[RunnableConfig] = None,
|
||||
) -> List[Document]:
|
||||
"""
|
||||
Retrieve the results of the retrievers and use rank_fusion_func to get
|
||||
@ -97,8 +184,11 @@ class EnsembleRetriever(BaseRetriever):
|
||||
|
||||
# Get the results of all retrievers.
|
||||
retriever_docs = [
|
||||
retriever.get_relevant_documents(
|
||||
query, callbacks=run_manager.get_child(tag=f"retriever_{i+1}")
|
||||
retriever.invoke(
|
||||
query,
|
||||
patch_config(
|
||||
config, callbacks=run_manager.get_child(tag=f"retriever_{i+1}")
|
||||
),
|
||||
)
|
||||
for i, retriever in enumerate(self.retrievers)
|
||||
]
|
||||
@ -116,7 +206,11 @@ class EnsembleRetriever(BaseRetriever):
|
||||
return fused_documents
|
||||
|
||||
async def arank_fusion(
|
||||
self, query: str, run_manager: AsyncCallbackManagerForRetrieverRun
|
||||
self,
|
||||
query: str,
|
||||
run_manager: AsyncCallbackManagerForRetrieverRun,
|
||||
*,
|
||||
config: Optional[RunnableConfig] = None,
|
||||
) -> List[Document]:
|
||||
"""
|
||||
Asynchronously retrieve the results of the retrievers
|
||||
@ -130,12 +224,17 @@ class EnsembleRetriever(BaseRetriever):
|
||||
"""
|
||||
|
||||
# Get the results of all retrievers.
|
||||
retriever_docs = [
|
||||
await retriever.aget_relevant_documents(
|
||||
query, callbacks=run_manager.get_child(tag=f"retriever_{i+1}")
|
||||
)
|
||||
for i, retriever in enumerate(self.retrievers)
|
||||
]
|
||||
retriever_docs = await asyncio.gather(
|
||||
*[
|
||||
retriever.ainvoke(
|
||||
query,
|
||||
patch_config(
|
||||
config, callbacks=run_manager.get_child(tag=f"retriever_{i+1}")
|
||||
),
|
||||
)
|
||||
for i, retriever in enumerate(self.retrievers)
|
||||
]
|
||||
)
|
||||
|
||||
# Enforce that retrieved docs are Documents for each list in retriever_docs
|
||||
for i in range(len(retriever_docs)):
|
||||
|
Loading…
Reference in New Issue
Block a user