mirror of
https://github.com/hwchase17/langchain.git
synced 2025-09-01 19:12:42 +00:00
Add hybrid retriever that does not require any external service (#8108)
- Until now, hybrid search was limited to modules requiring external services, such as Weaviate/Pinecone Hybrid Search. However, I have developed a hybrid retriever that can merge a list of retrievers using the [Reciprocal Rank Fusion](https://plg.uwaterloo.ca/~gvcormac/cormacksigir09-rrf.pdf) algorithm. This new approach, similar to Weaviate hybrid search, does not require the initialization of any external service. - Dependencies: No - Twitter handle: dayuanjian21687 --------- Co-authored-by: Bagatur <baskaryan@gmail.com>
This commit is contained in:
@@ -6,6 +6,7 @@ from langchain.retrievers.chatgpt_plugin_retriever import ChatGPTPluginRetriever
|
||||
from langchain.retrievers.contextual_compression import ContextualCompressionRetriever
|
||||
from langchain.retrievers.docarray import DocArrayRetriever
|
||||
from langchain.retrievers.elastic_search_bm25 import ElasticSearchBM25Retriever
|
||||
from langchain.retrievers.ensemble import EnsembleRetriever
|
||||
from langchain.retrievers.google_cloud_enterprise_search import (
|
||||
GoogleCloudEnterpriseSearchRetriever,
|
||||
)
|
||||
@@ -64,4 +65,5 @@ __all__ = [
|
||||
"ZepRetriever",
|
||||
"ZillizRetriever",
|
||||
"DocArrayRetriever",
|
||||
"EnsembleRetriever",
|
||||
]
|
||||
|
184
libs/langchain/langchain/retrievers/ensemble.py
Normal file
184
libs/langchain/langchain/retrievers/ensemble.py
Normal file
@@ -0,0 +1,184 @@
|
||||
"""
|
||||
Ensemble retriever that ensemble the results of
|
||||
multiple retrievers by using weighted Reciprocal Rank Fusion
|
||||
"""
|
||||
from typing import Any, Dict, List
|
||||
|
||||
from pydantic import root_validator
|
||||
|
||||
from langchain.callbacks.manager import (
|
||||
AsyncCallbackManagerForRetrieverRun,
|
||||
CallbackManagerForRetrieverRun,
|
||||
)
|
||||
from langchain.schema import BaseRetriever, Document
|
||||
|
||||
|
||||
class EnsembleRetriever(BaseRetriever):
    """
    Combine the results of multiple retrievers using weighted rank fusion.

    Each sub-retriever is queried independently and the ranked lists are
    merged with weighted Reciprocal Rank Fusion (RRF); see
    https://plg.uwaterloo.ca/~gvcormac/cormacksigir09-rrf.pdf

    Args:
        retrievers: A list of retrievers to ensemble.
        weights: A list of weights corresponding to the retrievers. Defaults to
            equal weighting for all retrievers.
        c: A constant added to the rank, controlling the balance between the
            importance of high-ranked items and the consideration given to
            lower-ranked items. Default is 60.
    """

    # Sub-retrievers whose ranked results are fused together.
    retrievers: List[BaseRetriever]
    # Per-retriever weight applied to each reciprocal-rank contribution.
    weights: List[float]
    # Rank-smoothing constant from the RRF paper (default 60).
    c: int = 60

    @root_validator(pre=True)
    def set_weights(cls, values: Dict[str, Any]) -> Dict[str, Any]:
        # When weights are missing (or empty/falsy), fall back to a uniform
        # weighting across all supplied retrievers.
        if not values.get("weights"):
            n_retrievers = len(values["retrievers"])
            values["weights"] = [1 / n_retrievers] * n_retrievers
        return values

    def _get_relevant_documents(
        self,
        query: str,
        *,
        run_manager: CallbackManagerForRetrieverRun,
    ) -> List[Document]:
        """
        Get the relevant documents for a given query.

        Args:
            query: The query to search for.

        Returns:
            A list of reranked documents.
        """
        # Delegate: fan out to the sub-retrievers and fuse their rankings.
        return self.rank_fusion(query, run_manager)

    async def _aget_relevant_documents(
        self,
        query: str,
        *,
        run_manager: AsyncCallbackManagerForRetrieverRun,
    ) -> List[Document]:
        """
        Asynchronously get the relevant documents for a given query.

        Args:
            query: The query to search for.

        Returns:
            A list of reranked documents.
        """
        # Async counterpart of _get_relevant_documents.
        return await self.arank_fusion(query, run_manager)

    def rank_fusion(
        self, query: str, run_manager: CallbackManagerForRetrieverRun
    ) -> List[Document]:
        """
        Query every sub-retriever and fuse the ranked lists with weighted RRF.

        Args:
            query: The query to search for.

        Returns:
            A list of reranked documents.
        """
        # Collect one ranked document list per sub-retriever; each call gets a
        # tagged child callback manager for traceability.
        per_retriever_docs = []
        for index, sub_retriever in enumerate(self.retrievers):
            docs = sub_retriever.get_relevant_documents(
                query, callbacks=run_manager.get_child(tag=f"retriever_{index+1}")
            )
            per_retriever_docs.append(docs)

        # Merge the individual rankings into one fused ranking.
        return self.weighted_reciprocal_rank(per_retriever_docs)

    async def arank_fusion(
        self, query: str, run_manager: AsyncCallbackManagerForRetrieverRun
    ) -> List[Document]:
        """
        Asynchronously query every sub-retriever and fuse the ranked lists
        with weighted RRF.

        Args:
            query: The query to search for.

        Returns:
            A list of reranked documents.
        """
        # Sequentially await each sub-retriever, tagging its callbacks.
        per_retriever_docs = []
        for index, sub_retriever in enumerate(self.retrievers):
            docs = await sub_retriever.aget_relevant_documents(
                query, callbacks=run_manager.get_child(tag=f"retriever_{index+1}")
            )
            per_retriever_docs.append(docs)

        # Merging itself is CPU-only, so the sync fusion helper is reused.
        return self.weighted_reciprocal_rank(per_retriever_docs)

    def weighted_reciprocal_rank(
        self, doc_lists: List[List[Document]]
    ) -> List[Document]:
        """
        Perform weighted Reciprocal Rank Fusion on multiple rank lists.

        You can find more details about RRF here:
        https://plg.uwaterloo.ca/~gvcormac/cormacksigir09-rrf.pdf

        Args:
            doc_lists: A list of rank lists, where each rank list contains unique items.

        Returns:
            list: The final aggregated list of items sorted by their weighted RRF
                scores in descending order.

        Raises:
            ValueError: If the number of rank lists does not match the number
                of configured weights.
        """
        if len(doc_lists) != len(self.weights):
            raise ValueError(
                "Number of rank lists must be equal to the number of weights."
            )

        # NOTE(review): documents are keyed by page_content, so two documents
        # with identical text but different metadata collapse into one entry —
        # acceptable here, but worth confirming for metadata-heavy corpora.
        all_documents = set()
        for ranked_docs in doc_lists:
            for document in ranked_docs:
                all_documents.add(document.page_content)

        # Every candidate starts at score 0.0. The set's iteration order fixes
        # the tie-breaking order below, since sorted() is stable.
        rrf_scores = {content: 0.0 for content in all_documents}

        # Accumulate each document's weighted reciprocal-rank contributions;
        # ranks are 1-based and dampened by the constant c.
        for ranked_docs, weight in zip(doc_lists, self.weights):
            for rank, document in enumerate(ranked_docs, start=1):
                rrf_scores[document.page_content] += weight * (1 / (rank + self.c))

        # Order candidates by fused score, best first.
        ordered_contents = sorted(
            rrf_scores, key=lambda content: rrf_scores[content], reverse=True
        )

        # Recover the Document objects; when the same content appears in
        # several lists, the instance from the last list wins.
        doc_by_content = {
            document.page_content: document
            for ranked_docs in doc_lists
            for document in ranked_docs
        }
        return [doc_by_content[content] for content in ordered_contents]
|
42
libs/langchain/tests/unit_tests/retrievers/test_ensemble.py
Normal file
42
libs/langchain/tests/unit_tests/retrievers/test_ensemble.py
Normal file
@@ -0,0 +1,42 @@
|
||||
import pytest
|
||||
|
||||
from langchain.retrievers.bm25 import BM25Retriever
|
||||
from langchain.retrievers.ensemble import EnsembleRetriever
|
||||
from langchain.schema import Document
|
||||
|
||||
|
||||
@pytest.mark.requires("rank_bm25")
def test_ensemble_retriever_get_relevant_docs() -> None:
    # Small corpus backing a BM25 retriever that returns a single hit.
    corpus = [
        "I like apples",
        "I like oranges",
        "Apples and oranges are fruits",
    ]

    bm25 = BM25Retriever.from_texts(corpus)
    bm25.k = 1

    # Weights are omitted on purpose: the validator should default to
    # uniform weighting across both (identical) retrievers.
    retriever = EnsembleRetriever(retrievers=[bm25, bm25])
    fused = retriever.get_relevant_documents("I like apples")

    # Both retrievers return the same single document, so fusion dedupes to 1.
    assert len(fused) == 1
|
||||
|
||||
|
||||
@pytest.mark.requires("rank_bm25")
def test_weighted_reciprocal_rank() -> None:
    doc_a = Document(page_content="1")
    doc_b = Document(page_content="2")

    bm25 = BM25Retriever.from_texts(["1", "2"])
    retriever = EnsembleRetriever(
        retrievers=[bm25, bm25], weights=[0.4, 0.5], c=0
    )

    # The second list carries more weight, so its top document ("2") wins.
    ranked = retriever.weighted_reciprocal_rank([[doc_a, doc_b], [doc_b, doc_a]])
    assert ranked[0].page_content == "2"
    assert ranked[1].page_content == "1"

    # Flipping the weights reverses the fused ordering.
    retriever.weights = [0.5, 0.4]
    ranked = retriever.weighted_reciprocal_rank([[doc_a, doc_b], [doc_b, doc_a]])
    assert ranked[0].page_content == "1"
    assert ranked[1].page_content == "2"
|
Reference in New Issue
Block a user