mirror of
https://github.com/hwchase17/langchain.git
synced 2025-08-17 00:17:47 +00:00
langchain[patch]: Simplify ensemble retriever (#14427)
- **Description:** code simplification to improve readability and remove unnecessary memory allocations. - **Tag maintainer**: @baskaryan, @eyurtsev, @hwchase17. --------- Co-authored-by: Bagatur <baskaryan@gmail.com>
This commit is contained in:
parent
b36f4147b0
commit
f5d4ce840f
@ -1,9 +1,22 @@
|
|||||||
"""
|
"""
|
||||||
Ensemble retriever that ensemble the results of
|
Ensemble retriever that ensemble the results of
|
||||||
multiple retrievers by using weighted Reciprocal Rank Fusion
|
multiple retrievers by using weighted Reciprocal Rank Fusion
|
||||||
"""
|
"""
|
||||||
import asyncio
|
import asyncio
|
||||||
from typing import Any, Dict, List, Optional, cast
|
from collections import defaultdict
|
||||||
|
from collections.abc import Hashable
|
||||||
|
from itertools import chain
|
||||||
|
from typing import (
|
||||||
|
Any,
|
||||||
|
Callable,
|
||||||
|
Dict,
|
||||||
|
Iterable,
|
||||||
|
Iterator,
|
||||||
|
List,
|
||||||
|
Optional,
|
||||||
|
TypeVar,
|
||||||
|
cast,
|
||||||
|
)
|
||||||
|
|
||||||
from langchain_core.callbacks import (
|
from langchain_core.callbacks import (
|
||||||
AsyncCallbackManagerForRetrieverRun,
|
AsyncCallbackManagerForRetrieverRun,
|
||||||
@ -20,6 +33,17 @@ from langchain_core.runnables.utils import (
|
|||||||
get_unique_config_specs,
|
get_unique_config_specs,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
T = TypeVar("T")
|
||||||
|
H = TypeVar("H", bound=Hashable)
|
||||||
|
|
||||||
|
|
||||||
|
def unique_by_key(iterable: Iterable[T], key: Callable[[T], H]) -> Iterator[T]:
|
||||||
|
seen = set()
|
||||||
|
for e in iterable:
|
||||||
|
if (k := key(e)) not in seen:
|
||||||
|
seen.add(k)
|
||||||
|
yield e
|
||||||
|
|
||||||
|
|
||||||
class EnsembleRetriever(BaseRetriever):
|
class EnsembleRetriever(BaseRetriever):
|
||||||
"""Retriever that ensembles the multiple retrievers.
|
"""Retriever that ensembles the multiple retrievers.
|
||||||
@ -267,32 +291,18 @@ class EnsembleRetriever(BaseRetriever):
|
|||||||
"Number of rank lists must be equal to the number of weights."
|
"Number of rank lists must be equal to the number of weights."
|
||||||
)
|
)
|
||||||
|
|
||||||
# Create a union of all unique documents in the input doc_lists
|
# Associate each doc's content with its RRF score for later sorting by it
|
||||||
all_documents = set()
|
# Duplicated contents across retrievers are collapsed & scored cumulatively
|
||||||
for doc_list in doc_lists:
|
rrf_score: Dict[str, float] = defaultdict(float)
|
||||||
for doc in doc_list:
|
|
||||||
all_documents.add(doc.page_content)
|
|
||||||
|
|
||||||
# Initialize the RRF score dictionary for each document
|
|
||||||
rrf_score_dic = {doc: 0.0 for doc in all_documents}
|
|
||||||
|
|
||||||
# Calculate RRF scores for each document
|
|
||||||
for doc_list, weight in zip(doc_lists, self.weights):
|
for doc_list, weight in zip(doc_lists, self.weights):
|
||||||
for rank, doc in enumerate(doc_list, start=1):
|
for rank, doc in enumerate(doc_list, start=1):
|
||||||
rrf_score = weight * (1 / (rank + self.c))
|
rrf_score[doc.page_content] += weight / (rank + self.c)
|
||||||
rrf_score_dic[doc.page_content] += rrf_score
|
|
||||||
|
|
||||||
# Sort documents by their RRF scores in descending order
|
# Docs are deduplicated by their contents then sorted by their scores
|
||||||
sorted_documents = sorted(
|
all_docs = chain.from_iterable(doc_lists)
|
||||||
rrf_score_dic.keys(), key=lambda x: rrf_score_dic[x], reverse=True
|
sorted_docs = sorted(
|
||||||
|
unique_by_key(all_docs, lambda doc: doc.page_content),
|
||||||
|
reverse=True,
|
||||||
|
key=lambda doc: rrf_score[doc.page_content],
|
||||||
)
|
)
|
||||||
|
|
||||||
# Map the sorted page_content back to the original document objects
|
|
||||||
page_content_to_doc_map = {
|
|
||||||
doc.page_content: doc for doc_list in doc_lists for doc in doc_list
|
|
||||||
}
|
|
||||||
sorted_docs = [
|
|
||||||
page_content_to_doc_map[page_content] for page_content in sorted_documents
|
|
||||||
]
|
|
||||||
|
|
||||||
return sorted_docs
|
return sorted_docs
|
||||||
|
Loading…
Reference in New Issue
Block a user