langchain[patch]: Simplify ensemble retriever (#14427)

- **Description:** Code simplification to improve readability and remove unnecessary memory allocations.
  - **Tag maintainer**: @baskaryan, @eyurtsev, @hwchase17.

---------

Co-authored-by: Bagatur <baskaryan@gmail.com>
Author: Ahmed Moubtahij, 2024-03-29 19:49:49 -04:00 (committed by GitHub)
Parent: b36f4147b0
Commit: f5d4ce840f


@@ -3,7 +3,20 @@ Ensemble retriever that ensemble the results of
 multiple retrievers by using weighted Reciprocal Rank Fusion
 """
 import asyncio
-from typing import Any, Dict, List, Optional, cast
+from collections import defaultdict
+from collections.abc import Hashable
+from itertools import chain
+from typing import (
+    Any,
+    Callable,
+    Dict,
+    Iterable,
+    Iterator,
+    List,
+    Optional,
+    TypeVar,
+    cast,
+)
 
 from langchain_core.callbacks import (
     AsyncCallbackManagerForRetrieverRun,
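
The new imports are all standard-library utilities. As a rough, self-contained sketch (placeholder string lists stand in for the per-retriever `Document` lists), this is how `chain.from_iterable` and `defaultdict(float)` are typically used:

```python
from collections import defaultdict
from itertools import chain

# Placeholder per-retriever result lists (plain strings, not real Documents)
doc_lists = [["a", "b"], ["b", "c"]]

# chain.from_iterable lazily flattens the nested lists without copying them
flat = chain.from_iterable(doc_lists)

# defaultdict(float) treats every missing key as 0.0, so scores can be
# accumulated without first building a dictionary of all possible keys
scores = defaultdict(float)
for item in flat:
    scores[item] += 1.0

print(dict(scores))  # {'a': 1.0, 'b': 2.0, 'c': 1.0}
```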
@@ -20,6 +33,17 @@ from langchain_core.runnables.utils import (
     get_unique_config_specs,
 )
 
+T = TypeVar("T")
+H = TypeVar("H", bound=Hashable)
+
+
+def unique_by_key(iterable: Iterable[T], key: Callable[[T], H]) -> Iterator[T]:
+    seen = set()
+    for e in iterable:
+        if (k := key(e)) not in seen:
+            seen.add(k)
+            yield e
+
 
 class EnsembleRetriever(BaseRetriever):
     """Retriever that ensembles the multiple retrievers.
@@ -267,32 +291,18 @@ class EnsembleRetriever(BaseRetriever):
                 "Number of rank lists must be equal to the number of weights."
             )
 
-        # Create a union of all unique documents in the input doc_lists
-        all_documents = set()
-        for doc_list in doc_lists:
-            for doc in doc_list:
-                all_documents.add(doc.page_content)
-
-        # Initialize the RRF score dictionary for each document
-        rrf_score_dic = {doc: 0.0 for doc in all_documents}
-
-        # Calculate RRF scores for each document
+        # Associate each doc's content with its RRF score for later sorting by it
+        # Duplicated contents across retrievers are collapsed & scored cumulatively
+        rrf_score: Dict[str, float] = defaultdict(float)
         for doc_list, weight in zip(doc_lists, self.weights):
             for rank, doc in enumerate(doc_list, start=1):
-                rrf_score = weight * (1 / (rank + self.c))
-                rrf_score_dic[doc.page_content] += rrf_score
+                rrf_score[doc.page_content] += weight / (rank + self.c)
 
-        # Sort documents by their RRF scores in descending order
-        sorted_documents = sorted(
-            rrf_score_dic.keys(), key=lambda x: rrf_score_dic[x], reverse=True
+        # Docs are deduplicated by their contents then sorted by their scores
+        all_docs = chain.from_iterable(doc_lists)
+        sorted_docs = sorted(
+            unique_by_key(all_docs, lambda doc: doc.page_content),
+            reverse=True,
+            key=lambda doc: rrf_score[doc.page_content],
         )
-
-        # Map the sorted page_content back to the original document objects
-        page_content_to_doc_map = {
-            doc.page_content: doc for doc_list in doc_lists for doc in doc_list
-        }
-
-        sorted_docs = [
-            page_content_to_doc_map[page_content] for page_content in sorted_documents
-        ]
-
         return sorted_docs
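
As a sanity check on the rewritten scoring, here is a minimal, self-contained sketch of the weighted Reciprocal Rank Fusion step with plain strings standing in for `Document.page_content`; the lists, weights, and `c = 60` constant are illustrative assumptions (the retriever's actual values come from `self.weights` and `self.c`):

```python
from collections import defaultdict
from itertools import chain

# Hypothetical ranked result lists (one per retriever) and their weights
doc_lists = [["apple", "banana", "cherry"], ["banana", "cherry"]]
weights = [0.7, 0.3]
c = 60  # assumed RRF constant, mirroring the retriever's `self.c`

# Accumulate weighted reciprocal-rank scores; duplicates across lists add up
rrf_score = defaultdict(float)
for doc_list, weight in zip(doc_lists, weights):
    for rank, doc in enumerate(doc_list, start=1):
        rrf_score[doc] += weight / (rank + c)

# Deduplicate (first occurrence wins) and sort by descending fused score
unique_docs = list(dict.fromkeys(chain.from_iterable(doc_lists)))
print(sorted(unique_docs, reverse=True, key=lambda d: rrf_score[d]))
# ['banana', 'cherry', 'apple']: documents returned by both retrievers
# outrank 'apple', which only the first retriever returned.
```

Because the scores are accumulated in a single pass and the dedup/sort works directly on the original document objects, the intermediate `all_documents` set, the pre-initialized score dictionary, and the `page_content_to_doc_map` from the old version are no longer needed.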