mirror of
https://github.com/hwchase17/langchain.git
synced 2025-09-24 20:09:01 +00:00
Pass config specs through ensemble retriever (#15917)
<!-- Thank you for contributing to LangChain! Please title your PR "<package>: <description>", where <package> is whichever of langchain, community, core, experimental, etc. is being modified. Replace this entire comment with: - **Description:** a description of the change, - **Issue:** the issue # it fixes if applicable, - **Dependencies:** any dependencies required for this change, - **Twitter handle:** we announce bigger features on Twitter. If your PR gets announced, and you'd like a mention, we'll gladly shout you out! Please make sure your PR is passing linting and testing before submitting. Run `make format`, `make lint` and `make test` from the root of the package you've modified to check this locally. See contribution guidelines for more information on how to write/run tests, lint, etc: https://python.langchain.com/docs/contributing/ If you're adding a new integration, please include: 1. a test for the integration, preferably unit tests that do not rely on network access, 2. an example notebook showing its use. It lives in `docs/docs/integrations` directory. If no one reviews your PR within a few days, please @-mention one of @baskaryan, @eyurtsev, @hwchase17. --> --------- Co-authored-by: Harrison Chase <hw.chase.17@gmail.com>
This commit is contained in:
@@ -2,11 +2,19 @@
|
||||
Ensemble retriever that ensembles the results of
|
||||
multiple retrievers by using weighted Reciprocal Rank Fusion
|
||||
"""
|
||||
from typing import Any, Dict, List
|
||||
import asyncio
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from langchain_core.documents import Document
|
||||
from langchain_core.load.dump import dumpd
|
||||
from langchain_core.pydantic_v1 import root_validator
|
||||
from langchain_core.retrievers import BaseRetriever
|
||||
from langchain_core.retrievers import BaseRetriever, RetrieverLike
|
||||
from langchain_core.runnables import RunnableConfig
|
||||
from langchain_core.runnables.config import ensure_config, patch_config
|
||||
from langchain_core.runnables.utils import (
|
||||
ConfigurableFieldSpec,
|
||||
get_unique_config_specs,
|
||||
)
|
||||
|
||||
from langchain.callbacks.manager import (
|
||||
AsyncCallbackManagerForRetrieverRun,
|
||||
@@ -28,10 +36,17 @@ class EnsembleRetriever(BaseRetriever):
|
||||
Default is 60.
|
||||
"""
|
||||
|
||||
retrievers: List[BaseRetriever]
|
||||
retrievers: List[RetrieverLike]
|
||||
weights: List[float]
|
||||
c: int = 60
|
||||
|
||||
@property
def config_specs(self) -> List[ConfigurableFieldSpec]:
    """Collect the configurable field specs exposed by the wrapped retrievers.

    Every retriever in ``self.retrievers`` contributes its own
    ``config_specs``; the combined collection is handed to
    ``get_unique_config_specs`` and that function's result is returned.
    """
    gathered: List[ConfigurableFieldSpec] = []
    for member in self.retrievers:
        # Preserve retriever order: specs are appended in the same sequence
        # the nested generator would have yielded them.
        gathered.extend(member.config_specs)
    return get_unique_config_specs(gathered)
|
||||
|
||||
@root_validator(pre=True)
|
||||
def set_weights(cls, values: Dict[str, Any]) -> Dict[str, Any]:
|
||||
if not values.get("weights"):
|
||||
@@ -39,6 +54,74 @@ class EnsembleRetriever(BaseRetriever):
|
||||
values["weights"] = [1 / n_retrievers] * n_retrievers
|
||||
return values
|
||||
|
||||
def invoke(
    self, input: str, config: Optional[RunnableConfig] = None, **kwargs: Any
) -> List[Document]:
    """Run the ensemble synchronously and report the run to retriever callbacks.

    Args:
        input: The query string; forwarded to ``rank_fusion``.
        config: Optional runnable config. It is normalised with
            ``ensure_config`` and supplies the callbacks, tags, metadata and
            run name used for this run; it is also passed on to
            ``rank_fusion``.
        **kwargs: ``verbose`` is read from here (default ``False``); the full
            mapping is forwarded to ``on_retriever_start`` and
            ``on_retriever_end``.

    Returns:
        The list of documents produced by ``rank_fusion``.

    Raises:
        Exception: Whatever ``rank_fusion`` raises, after the error has been
            reported through ``on_retriever_error``.
    """
    from langchain_core.callbacks.manager import CallbackManager

    config = ensure_config(config)
    manager = CallbackManager.configure(
        config.get("callbacks"),
        None,
        verbose=kwargs.get("verbose", False),
        inheritable_tags=config.get("tags", []),
        local_tags=self.tags,
        inheritable_metadata=config.get("metadata", {}),
        local_metadata=self.metadata,
    )
    run_manager = manager.on_retriever_start(
        dumpd(self),
        input,
        name=config.get("run_name"),
        **kwargs,
    )
    try:
        documents = self.rank_fusion(input, run_manager=run_manager, config=config)
    except Exception as e:
        run_manager.on_retriever_error(e)
        raise e
    # Only reached on success: the except branch above always re-raises.
    run_manager.on_retriever_end(
        documents,
        **kwargs,
    )
    return documents
|
||||
|
||||
async def ainvoke(
    self, input: str, config: Optional[RunnableConfig] = None, **kwargs: Any
) -> List[Document]:
    """Run the ensemble asynchronously and report the run to retriever callbacks.

    Args:
        input: The query string; forwarded to ``arank_fusion``.
        config: Optional runnable config. It is normalised with
            ``ensure_config`` and supplies the callbacks, tags, metadata and
            run name used for this run; it is also passed on to
            ``arank_fusion``.
        **kwargs: ``verbose`` is read from here (default ``False``); the full
            mapping is forwarded to ``on_retriever_start`` and
            ``on_retriever_end``.

    Returns:
        The list of documents produced by ``arank_fusion``.

    Raises:
        Exception: Whatever ``arank_fusion`` raises, after the error has been
            reported through ``on_retriever_error``.
    """
    from langchain_core.callbacks.manager import AsyncCallbackManager

    config = ensure_config(config)
    manager = AsyncCallbackManager.configure(
        config.get("callbacks"),
        None,
        verbose=kwargs.get("verbose", False),
        inheritable_tags=config.get("tags", []),
        local_tags=self.tags,
        inheritable_metadata=config.get("metadata", {}),
        local_metadata=self.metadata,
    )
    run_manager = await manager.on_retriever_start(
        dumpd(self),
        input,
        name=config.get("run_name"),
        **kwargs,
    )
    try:
        documents = await self.arank_fusion(
            input, run_manager=run_manager, config=config
        )
    except Exception as e:
        await run_manager.on_retriever_error(e)
        raise e
    # Only reached on success: the except branch above always re-raises.
    await run_manager.on_retriever_end(
        documents,
        **kwargs,
    )
    return documents
|
||||
|
||||
def _get_relevant_documents(
|
||||
self,
|
||||
query: str,
|
||||
@@ -82,7 +165,11 @@ class EnsembleRetriever(BaseRetriever):
|
||||
return fused_documents
|
||||
|
||||
def rank_fusion(
|
||||
self, query: str, run_manager: CallbackManagerForRetrieverRun
|
||||
self,
|
||||
query: str,
|
||||
run_manager: CallbackManagerForRetrieverRun,
|
||||
*,
|
||||
config: Optional[RunnableConfig] = None,
|
||||
) -> List[Document]:
|
||||
"""
|
||||
Retrieve the results of the retrievers and use rank_fusion_func to get
|
||||
@@ -97,8 +184,11 @@ class EnsembleRetriever(BaseRetriever):
|
||||
|
||||
# Get the results of all retrievers.
|
||||
retriever_docs = [
|
||||
retriever.get_relevant_documents(
|
||||
query, callbacks=run_manager.get_child(tag=f"retriever_{i+1}")
|
||||
retriever.invoke(
|
||||
query,
|
||||
patch_config(
|
||||
config, callbacks=run_manager.get_child(tag=f"retriever_{i+1}")
|
||||
),
|
||||
)
|
||||
for i, retriever in enumerate(self.retrievers)
|
||||
]
|
||||
@@ -116,7 +206,11 @@ class EnsembleRetriever(BaseRetriever):
|
||||
return fused_documents
|
||||
|
||||
async def arank_fusion(
|
||||
self, query: str, run_manager: AsyncCallbackManagerForRetrieverRun
|
||||
self,
|
||||
query: str,
|
||||
run_manager: AsyncCallbackManagerForRetrieverRun,
|
||||
*,
|
||||
config: Optional[RunnableConfig] = None,
|
||||
) -> List[Document]:
|
||||
"""
|
||||
Asynchronously retrieve the results of the retrievers
|
||||
@@ -130,12 +224,17 @@ class EnsembleRetriever(BaseRetriever):
|
||||
"""
|
||||
|
||||
# Get the results of all retrievers.
|
||||
retriever_docs = [
|
||||
await retriever.aget_relevant_documents(
|
||||
query, callbacks=run_manager.get_child(tag=f"retriever_{i+1}")
|
||||
)
|
||||
for i, retriever in enumerate(self.retrievers)
|
||||
]
|
||||
retriever_docs = await asyncio.gather(
|
||||
*[
|
||||
retriever.ainvoke(
|
||||
query,
|
||||
patch_config(
|
||||
config, callbacks=run_manager.get_child(tag=f"retriever_{i+1}")
|
||||
),
|
||||
)
|
||||
for i, retriever in enumerate(self.retrievers)
|
||||
]
|
||||
)
|
||||
|
||||
# Enforce that retrieved docs are Documents for each list in retriever_docs
|
||||
for i in range(len(retriever_docs)):
|
||||
|
Reference in New Issue
Block a user