From 47fd3f75a258760c19d9f6a7e52589f3128c690f Mon Sep 17 00:00:00 2001 From: lorenzofavaro <44714920+lorenzofavaro@users.noreply.github.com> Date: Wed, 30 Aug 2023 23:21:25 +0200 Subject: [PATCH] Metadata encoder parameterization --- libs/langchain/langchain/retrievers/multi_query.py | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/libs/langchain/langchain/retrievers/multi_query.py b/libs/langchain/langchain/retrievers/multi_query.py index c98e6f9dbce..79ed26d1375 100644 --- a/libs/langchain/langchain/retrievers/multi_query.py +++ b/libs/langchain/langchain/retrievers/multi_query.py @@ -1,6 +1,7 @@ import json import logging -from typing import List +from functools import partial +from typing import Callable, List from langchain.callbacks.manager import CallbackManagerForRetrieverRun from langchain.chains.llm import LLMChain @@ -137,7 +138,11 @@ class MultiQueryRetriever(BaseRetriever): documents.extend(docs) return documents - def unique_union(self, documents: List[Document]) -> List[Document]: + def unique_union( + self, + documents: List[Document], + metadata_encoder: Callable = partial(json.dumps, sort_keys=True), + ) -> List[Document]: """Get unique Documents. Args: @@ -149,8 +154,7 @@ class MultiQueryRetriever(BaseRetriever): # Create a dictionary with page_content as keys to remove duplicates # TODO: Add Document ID property (e.g., UUID) unique_documents_dict = { - (doc.page_content, json.dumps(doc.metadata, sort_keys=True)): doc - for doc in documents + (doc.page_content, metadata_encoder(doc.metadata)): doc for doc in documents } unique_documents = list(unique_documents_dict.values())