mirror of
https://github.com/hwchase17/langchain.git
synced 2025-07-06 21:20:33 +00:00
Fix: Nested Dicts Handling of Document Metadata (#9880)
## Description When the `MultiQueryRetriever` is used to get the list of documents relevant according to a query, inside a vector store, and at least one of these contain metadata with nested dictionaries, a `TypeError: unhashable type: 'dict'` exception is thrown. This is caused by the `unique_union` function which, to guarantee the uniqueness of the returned documents, tries, unsuccessfully, to hash the nested dictionaries and use them as a part of key. ```python unique_documents_dict = { (doc.page_content, tuple(sorted(doc.metadata.items()))): doc for doc in documents } ``` ## Issue #9872 (MultiQueryRetriever (get_relevant_documents) raises TypeError: unhashable type: 'dict' with dic metadata) ## Solution A possible solution is to dump the metadata dict to a string and use it as a part of hashed key. ```python unique_documents_dict = { (doc.page_content, json.dumps(doc.metadata, sort_keys=True)): doc for doc in documents } ``` --------- Co-authored-by: Bagatur <baskaryan@gmail.com>
This commit is contained in:
parent
a52fe9528e
commit
00a7c31ffd
@ -1,5 +1,5 @@
|
||||
import logging
|
||||
from typing import List
|
||||
from typing import List, Sequence
|
||||
|
||||
from langchain.callbacks.manager import CallbackManagerForRetrieverRun
|
||||
from langchain.chains.llm import LLMChain
|
||||
@ -43,10 +43,14 @@ DEFAULT_QUERY_PROMPT = PromptTemplate(
|
||||
)
|
||||
|
||||
|
||||
def _unique_documents(documents: Sequence[Document]) -> List[Document]:
|
||||
return [doc for i, doc in enumerate(documents) if doc not in documents[:i]]
|
||||
|
||||
|
||||
class MultiQueryRetriever(BaseRetriever):
|
||||
"""Given a query, use an LLM to write a set of queries.
|
||||
|
||||
Retrieve docs for each query. Rake the unique union of all retrieved docs.
|
||||
Retrieve docs for each query. Return the unique union of all retrieved docs.
|
||||
"""
|
||||
|
||||
retriever: BaseRetriever
|
||||
@ -85,7 +89,7 @@ class MultiQueryRetriever(BaseRetriever):
|
||||
*,
|
||||
run_manager: CallbackManagerForRetrieverRun,
|
||||
) -> List[Document]:
|
||||
"""Get relevated documents given a user query.
|
||||
"""Get relevant documents given a user query.
|
||||
|
||||
Args:
|
||||
question: user query
|
||||
@ -95,8 +99,7 @@ class MultiQueryRetriever(BaseRetriever):
|
||||
"""
|
||||
queries = self.generate_queries(query, run_manager)
|
||||
documents = self.retrieve_documents(queries, run_manager)
|
||||
unique_documents = self.unique_union(documents)
|
||||
return unique_documents
|
||||
return self.unique_union(documents)
|
||||
|
||||
def generate_queries(
|
||||
self, question: str, run_manager: CallbackManagerForRetrieverRun
|
||||
@ -145,12 +148,4 @@ class MultiQueryRetriever(BaseRetriever):
|
||||
Returns:
|
||||
List of unique retrieved Documents
|
||||
"""
|
||||
# Create a dictionary with page_content as keys to remove duplicates
|
||||
# TODO: Add Document ID property (e.g., UUID)
|
||||
unique_documents_dict = {
|
||||
(doc.page_content, tuple(sorted(doc.metadata.items()))): doc
|
||||
for doc in documents
|
||||
}
|
||||
|
||||
unique_documents = list(unique_documents_dict.values())
|
||||
return unique_documents
|
||||
return _unique_documents(documents)
|
||||
|
@ -0,0 +1,40 @@
|
||||
from typing import List
|
||||
|
||||
import pytest as pytest
|
||||
|
||||
from langchain.retrievers.multi_query import _unique_documents
|
||||
from langchain.schema import Document
|
||||
|
||||
|
||||
# Each case is (input documents, expected deduplicated output). Covers:
# empty input, a single doc, exact duplicates, duplicates whose metadata
# holds unhashable values (lists, sets), and near-duplicates whose list
# metadata differs only in element order (must NOT be merged).
_UNIQUE_CASES = [
    ([], []),
    ([Document(page_content="foo")], [Document(page_content="foo")]),
    (
        [Document(page_content="foo"), Document(page_content="foo")],
        [Document(page_content="foo")],
    ),
    (
        [Document(page_content="foo", metadata={"bar": "baz"})] * 2,
        [Document(page_content="foo", metadata={"bar": "baz"})],
    ),
    (
        [Document(page_content="foo", metadata={"bar": [1, 2]})] * 2,
        [Document(page_content="foo", metadata={"bar": [1, 2]})],
    ),
    (
        [Document(page_content="foo", metadata={"bar": {1, 2}})] * 2,
        [Document(page_content="foo", metadata={"bar": {1, 2}})],
    ),
    (
        [
            Document(page_content="foo", metadata={"bar": [1, 2]}),
            Document(page_content="foo", metadata={"bar": [2, 1]}),
        ],
        [
            Document(page_content="foo", metadata={"bar": [1, 2]}),
            Document(page_content="foo", metadata={"bar": [2, 1]}),
        ],
    ),
]


@pytest.mark.parametrize("documents,expected", _UNIQUE_CASES)
def test__unique_documents(documents: List[Document], expected: List[Document]) -> None:
    """_unique_documents drops equal duplicates and preserves order."""
    assert _unique_documents(documents) == expected
|
Loading…
Reference in New Issue
Block a user