mirror of
https://github.com/hwchase17/langchain.git
synced 2025-06-12 01:46:12 +00:00
langchain: add id_key option to EnsembleRetriever for metadata-based document merging (#22950)
**Description:** - What I changed - By specifying the `id_key` during the initialization of `EnsembleRetriever`, it is now possible to determine which documents to merge scores for based on the value corresponding to the `id_key` element in the metadata, instead of `page_content`. Below is an example of how to use the modified `EnsembleRetriever`: ```python retriever = EnsembleRetriever(retrievers=[ret1, ret2], id_key="id") # The Document returned by each retriever must keep the "id" key in its metadata. ``` - Additionally, I added a script to easily test the behavior of the `invoke` method of the modified `EnsembleRetriever`. - Why I changed - There are cases where you may want to calculate scores by treating Documents with different `page_content` as the same when using `EnsembleRetriever`. For example, when you want to ensemble the search results of the same document described in two different languages. - The previous `EnsembleRetriever` used `page_content` as the basis for score aggregation, making the above usage difficult. Therefore, the score is now calculated based on the specified key value in the Document's metadata. **Twitter handle:** @shimajiroxyz
This commit is contained in:
parent
39f6c4169d
commit
3e835a1aa1
@ -66,11 +66,14 @@ class EnsembleRetriever(BaseRetriever):
|
|||||||
c: A constant added to the rank, controlling the balance between the importance
|
c: A constant added to the rank, controlling the balance between the importance
|
||||||
of high-ranked items and the consideration given to lower-ranked items.
|
of high-ranked items and the consideration given to lower-ranked items.
|
||||||
Default is 60.
|
Default is 60.
|
||||||
|
id_key: The key in the document's metadata used to determine unique documents.
|
||||||
|
If not specified, page_content is used.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
retrievers: List[RetrieverLike]
|
retrievers: List[RetrieverLike]
|
||||||
weights: List[float]
|
weights: List[float]
|
||||||
c: int = 60
|
c: int = 60
|
||||||
|
id_key: Optional[str] = None
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def config_specs(self) -> List[ConfigurableFieldSpec]:
|
def config_specs(self) -> List[ConfigurableFieldSpec]:
|
||||||
@ -305,13 +308,24 @@ class EnsembleRetriever(BaseRetriever):
|
|||||||
rrf_score: Dict[str, float] = defaultdict(float)
|
rrf_score: Dict[str, float] = defaultdict(float)
|
||||||
for doc_list, weight in zip(doc_lists, self.weights):
|
for doc_list, weight in zip(doc_lists, self.weights):
|
||||||
for rank, doc in enumerate(doc_list, start=1):
|
for rank, doc in enumerate(doc_list, start=1):
|
||||||
rrf_score[doc.page_content] += weight / (rank + self.c)
|
rrf_score[
|
||||||
|
doc.page_content
|
||||||
|
if self.id_key is None
|
||||||
|
else doc.metadata[self.id_key]
|
||||||
|
] += weight / (rank + self.c)
|
||||||
|
|
||||||
# Docs are deduplicated by their contents then sorted by their scores
|
# Docs are deduplicated by their contents then sorted by their scores
|
||||||
all_docs = chain.from_iterable(doc_lists)
|
all_docs = chain.from_iterable(doc_lists)
|
||||||
sorted_docs = sorted(
|
sorted_docs = sorted(
|
||||||
unique_by_key(all_docs, lambda doc: doc.page_content),
|
unique_by_key(
|
||||||
|
all_docs,
|
||||||
|
lambda doc: doc.page_content
|
||||||
|
if self.id_key is None
|
||||||
|
else doc.metadata[self.id_key],
|
||||||
|
),
|
||||||
reverse=True,
|
reverse=True,
|
||||||
key=lambda doc: rrf_score[doc.page_content],
|
key=lambda doc: rrf_score[
|
||||||
|
doc.page_content if self.id_key is None else doc.metadata[self.id_key]
|
||||||
|
],
|
||||||
)
|
)
|
||||||
return sorted_docs
|
return sorted_docs
|
||||||
|
88
libs/langchain/tests/unit_tests/retrievers/test_ensemble.py
Normal file
88
libs/langchain/tests/unit_tests/retrievers/test_ensemble.py
Normal file
@ -0,0 +1,88 @@
|
|||||||
|
from typing import List, Optional
|
||||||
|
|
||||||
|
from langchain_core.callbacks.manager import CallbackManagerForRetrieverRun
|
||||||
|
from langchain_core.documents import Document
|
||||||
|
from langchain_core.retrievers import BaseRetriever
|
||||||
|
|
||||||
|
from langchain.retrievers.ensemble import EnsembleRetriever
|
||||||
|
|
||||||
|
|
||||||
|
class MockRetriever(BaseRetriever):
|
||||||
|
docs: List[Document]
|
||||||
|
|
||||||
|
def _get_relevant_documents(
|
||||||
|
self,
|
||||||
|
query: str,
|
||||||
|
*,
|
||||||
|
run_manager: Optional[CallbackManagerForRetrieverRun] = None,
|
||||||
|
) -> List[Document]:
|
||||||
|
"""Return the documents"""
|
||||||
|
return self.docs
|
||||||
|
|
||||||
|
|
||||||
|
def test_invoke() -> None:
|
||||||
|
documents1 = [
|
||||||
|
Document(page_content="a", metadata={"id": 1}),
|
||||||
|
Document(page_content="b", metadata={"id": 2}),
|
||||||
|
Document(page_content="c", metadata={"id": 3}),
|
||||||
|
]
|
||||||
|
documents2 = [Document(page_content="b")]
|
||||||
|
|
||||||
|
retriever1 = MockRetriever(docs=documents1)
|
||||||
|
retriever2 = MockRetriever(docs=documents2)
|
||||||
|
|
||||||
|
ensemble_retriever = EnsembleRetriever(
|
||||||
|
retrievers=[retriever1, retriever2], weights=[0.5, 0.5], id_key=None
|
||||||
|
)
|
||||||
|
ranked_documents = ensemble_retriever.invoke("_")
|
||||||
|
|
||||||
|
# The document with page_content "b" in documents2
|
||||||
|
# will be merged with the document with page_content "b"
|
||||||
|
# in documents1, so the length of ranked_documents should be 3.
|
||||||
|
# Additionally, the document with page_content "b" will be ranked 1st.
|
||||||
|
assert len(ranked_documents) == 3
|
||||||
|
assert ranked_documents[0].page_content == "b"
|
||||||
|
|
||||||
|
documents1 = [
|
||||||
|
Document(page_content="a", metadata={"id": 1}),
|
||||||
|
Document(page_content="b", metadata={"id": 2}),
|
||||||
|
Document(page_content="c", metadata={"id": 3}),
|
||||||
|
]
|
||||||
|
documents2 = [Document(page_content="d")]
|
||||||
|
|
||||||
|
retriever1 = MockRetriever(docs=documents1)
|
||||||
|
retriever2 = MockRetriever(docs=documents2)
|
||||||
|
|
||||||
|
ensemble_retriever = EnsembleRetriever(
|
||||||
|
retrievers=[retriever1, retriever2], weights=[0.5, 0.5], id_key=None
|
||||||
|
)
|
||||||
|
ranked_documents = ensemble_retriever.invoke("_")
|
||||||
|
|
||||||
|
# The document with page_content "d" in documents2 will not be merged
|
||||||
|
# with any document in documents1, so the length of ranked_documents
|
||||||
|
# should be 4. The document with page_content "a" and the document
|
||||||
|
# with page_content "d" will have the same score, but the document
|
||||||
|
# with page_content "a" will be ranked 1st because retriever1 has a smaller index.
|
||||||
|
assert len(ranked_documents) == 4
|
||||||
|
assert ranked_documents[0].page_content == "a"
|
||||||
|
|
||||||
|
documents1 = [
|
||||||
|
Document(page_content="a", metadata={"id": 1}),
|
||||||
|
Document(page_content="b", metadata={"id": 2}),
|
||||||
|
Document(page_content="c", metadata={"id": 3}),
|
||||||
|
]
|
||||||
|
documents2 = [Document(page_content="d", metadata={"id": 2})]
|
||||||
|
|
||||||
|
retriever1 = MockRetriever(docs=documents1)
|
||||||
|
retriever2 = MockRetriever(docs=documents2)
|
||||||
|
|
||||||
|
ensemble_retriever = EnsembleRetriever(
|
||||||
|
retrievers=[retriever1, retriever2], weights=[0.5, 0.5], id_key="id"
|
||||||
|
)
|
||||||
|
ranked_documents = ensemble_retriever.invoke("_")
|
||||||
|
|
||||||
|
# Since id_key is specified, the document with id 2 will be merged.
|
||||||
|
# Therefore, the length of ranked_documents should be 3.
|
||||||
|
# Additionally, the document with page_content "b" will be ranked 1st.
|
||||||
|
assert len(ranked_documents) == 3
|
||||||
|
assert ranked_documents[0].page_content == "b"
|
Loading…
Reference in New Issue
Block a user