diff --git a/libs/langchain/langchain/retrievers/multi_query.py b/libs/langchain/langchain/retrievers/multi_query.py
index c98e6f9dbce..79ed26d1375 100644
--- a/libs/langchain/langchain/retrievers/multi_query.py
+++ b/libs/langchain/langchain/retrievers/multi_query.py
@@ -1,6 +1,7 @@
 import json
 import logging
-from typing import List
+from functools import partial
+from typing import Callable, List
 
 from langchain.callbacks.manager import CallbackManagerForRetrieverRun
 from langchain.chains.llm import LLMChain
@@ -137,7 +138,11 @@ class MultiQueryRetriever(BaseRetriever):
             documents.extend(docs)
         return documents
 
-    def unique_union(self, documents: List[Document]) -> List[Document]:
+    def unique_union(
+        self,
+        documents: List[Document],
+        metadata_encoder: Callable = partial(json.dumps, sort_keys=True),
+    ) -> List[Document]:
         """Get unique Documents.
 
         Args:
@@ -149,8 +154,7 @@
         # Create a dictionary with page_content as keys to remove duplicates
         # TODO: Add Document ID property (e.g., UUID)
         unique_documents_dict = {
-            (doc.page_content, json.dumps(doc.metadata, sort_keys=True)): doc
-            for doc in documents
+            (doc.page_content, metadata_encoder(doc.metadata)): doc for doc in documents
         }
 
         unique_documents = list(unique_documents_dict.values())
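A minimal sketch of why a pluggable `metadata_encoder` helps: the previous hard-coded `json.dumps(doc.metadata, sort_keys=True)` key raises a `TypeError` whenever metadata holds a value that is not JSON serializable. The datetime metadata and the `lenient_encoder` name below are illustrative assumptions, not part of the diff; any callable that maps a metadata dict to a hashable key would work.

```python
# Illustrative only: the kind of encoder a caller could pass as `metadata_encoder`
# once this change lands, e.g. retriever.unique_union(docs, metadata_encoder=lenient_encoder).
import json
from datetime import datetime
from functools import partial

# Hypothetical metadata containing a value json.dumps cannot serialize by default.
metadata = {"source": "notes.md", "ingested_at": datetime(2023, 9, 1)}

# Equivalent to the previous hard-coded key: fails on the datetime value.
try:
    json.dumps(metadata, sort_keys=True)
except TypeError as err:
    print(f"default encoder fails: {err}")

# A caller-supplied encoder that stringifies unknown types keeps deduplication working.
lenient_encoder = partial(json.dumps, sort_keys=True, default=str)
print(lenient_encoder(metadata))
```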