From f5d4ce840f675bf47420c523245565ea678e3035 Mon Sep 17 00:00:00 2001 From: Ahmed Moubtahij <50707385+Ayenem@users.noreply.github.com> Date: Fri, 29 Mar 2024 19:49:49 -0400 Subject: [PATCH] langchain[patch]: Simplify ensemble retriever (#14427) - **Description:** code simplification to improve readability and remove unnecessary memory allocations. - **Tag maintainer**: @baskaryan, @eyurtsev, @hwchase17. --------- Co-authored-by: Bagatur --- .../langchain/retrievers/ensemble.py | 62 +++++++++++-------- 1 file changed, 36 insertions(+), 26 deletions(-) diff --git a/libs/langchain/langchain/retrievers/ensemble.py b/libs/langchain/langchain/retrievers/ensemble.py index 835faba78a4..937a5e5549b 100644 --- a/libs/langchain/langchain/retrievers/ensemble.py +++ b/libs/langchain/langchain/retrievers/ensemble.py @@ -1,9 +1,22 @@ """ -Ensemble retriever that ensemble the results of +Ensemble retriever that ensemble the results of multiple retrievers by using weighted Reciprocal Rank Fusion """ import asyncio -from typing import Any, Dict, List, Optional, cast +from collections import defaultdict +from collections.abc import Hashable +from itertools import chain +from typing import ( + Any, + Callable, + Dict, + Iterable, + Iterator, + List, + Optional, + TypeVar, + cast, +) from langchain_core.callbacks import ( AsyncCallbackManagerForRetrieverRun, @@ -20,6 +33,17 @@ from langchain_core.runnables.utils import ( get_unique_config_specs, ) +T = TypeVar("T") +H = TypeVar("H", bound=Hashable) + + +def unique_by_key(iterable: Iterable[T], key: Callable[[T], H]) -> Iterator[T]: + seen = set() + for e in iterable: + if (k := key(e)) not in seen: + seen.add(k) + yield e + class EnsembleRetriever(BaseRetriever): """Retriever that ensembles the multiple retrievers. @@ -267,32 +291,18 @@ class EnsembleRetriever(BaseRetriever): "Number of rank lists must be equal to the number of weights." ) - # Create a union of all unique documents in the input doc_lists - all_documents = set() - for doc_list in doc_lists: - for doc in doc_list: - all_documents.add(doc.page_content) - - # Initialize the RRF score dictionary for each document - rrf_score_dic = {doc: 0.0 for doc in all_documents} - - # Calculate RRF scores for each document + # Associate each doc's content with its RRF score for later sorting by it + # Duplicated contents across retrievers are collapsed & scored cumulatively + rrf_score: Dict[str, float] = defaultdict(float) for doc_list, weight in zip(doc_lists, self.weights): for rank, doc in enumerate(doc_list, start=1): - rrf_score = weight * (1 / (rank + self.c)) - rrf_score_dic[doc.page_content] += rrf_score + rrf_score[doc.page_content] += weight / (rank + self.c) - # Sort documents by their RRF scores in descending order - sorted_documents = sorted( - rrf_score_dic.keys(), key=lambda x: rrf_score_dic[x], reverse=True + # Docs are deduplicated by their contents then sorted by their scores + all_docs = chain.from_iterable(doc_lists) + sorted_docs = sorted( + unique_by_key(all_docs, lambda doc: doc.page_content), + reverse=True, + key=lambda doc: rrf_score[doc.page_content], ) - - # Map the sorted page_content back to the original document objects - page_content_to_doc_map = { - doc.page_content: doc for doc_list in doc_lists for doc in doc_list - } - sorted_docs = [ - page_content_to_doc_map[page_content] for page_content in sorted_documents - ] - return sorted_docs