Metadata encoder parameterization

This commit is contained in:
lorenzofavaro
2023-08-30 23:21:25 +02:00
parent bec33a85bc
commit 47fd3f75a2

View File

@@ -1,6 +1,7 @@
import json
import logging
from typing import List
from functools import partial
from typing import Callable, List
from langchain.callbacks.manager import CallbackManagerForRetrieverRun
from langchain.chains.llm import LLMChain
@@ -137,7 +138,11 @@ class MultiQueryRetriever(BaseRetriever):
documents.extend(docs)
return documents
def unique_union(self, documents: List[Document]) -> List[Document]:
def unique_union(
self,
documents: List[Document],
metadata_encoder: Callable = partial(json.dumps, sort_keys=True),
) -> List[Document]:
"""Get unique Documents.
Args:
@@ -149,8 +154,7 @@ class MultiQueryRetriever(BaseRetriever):
# Create a dictionary with page_content as keys to remove duplicates
# TODO: Add Document ID property (e.g., UUID)
unique_documents_dict = {
(doc.page_content, json.dumps(doc.metadata, sort_keys=True)): doc
for doc in documents
(doc.page_content, metadata_encoder(doc.metadata)): doc for doc in documents
}
unique_documents = list(unique_documents_dict.values())