mirror of
https://github.com/hwchase17/langchain.git
synced 2025-07-07 13:40:46 +00:00
Fix: Nested Dicts Handling of Document Metadata (#9880)
## Description When the `MultiQueryRetriever` is used to get the list of documents relevant according to a query, inside a vector store, and at least one of these contains metadata with nested dictionaries, a `TypeError: unhashable type: 'dict'` exception is thrown. This is caused by the `unique_union` function which, to guarantee the uniqueness of the returned documents, tries, unsuccessfully, to hash the nested dictionaries and use them as part of the key. ```python unique_documents_dict = { (doc.page_content, tuple(sorted(doc.metadata.items()))): doc for doc in documents } ``` ## Issue #9872 (MultiQueryRetriever (get_relevant_documents) raises TypeError: unhashable type: 'dict' with dict metadata) ## Solution A possible solution is to dump the metadata dict to a string and use it as part of the hashed key. ```python unique_documents_dict = { (doc.page_content, json.dumps(doc.metadata, sort_keys=True)): doc for doc in documents } ``` --------- Co-authored-by: Bagatur <baskaryan@gmail.com>
This commit is contained in:
parent
a52fe9528e
commit
00a7c31ffd
@ -1,5 +1,5 @@
|
|||||||
import logging
|
import logging
|
||||||
from typing import List
|
from typing import List, Sequence
|
||||||
|
|
||||||
from langchain.callbacks.manager import CallbackManagerForRetrieverRun
|
from langchain.callbacks.manager import CallbackManagerForRetrieverRun
|
||||||
from langchain.chains.llm import LLMChain
|
from langchain.chains.llm import LLMChain
|
||||||
@ -43,10 +43,14 @@ DEFAULT_QUERY_PROMPT = PromptTemplate(
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _unique_documents(documents: Sequence[Document]) -> List[Document]:
|
||||||
|
return [doc for i, doc in enumerate(documents) if doc not in documents[:i]]
|
||||||
|
|
||||||
|
|
||||||
class MultiQueryRetriever(BaseRetriever):
|
class MultiQueryRetriever(BaseRetriever):
|
||||||
"""Given a query, use an LLM to write a set of queries.
|
"""Given a query, use an LLM to write a set of queries.
|
||||||
|
|
||||||
Retrieve docs for each query. Rake the unique union of all retrieved docs.
|
Retrieve docs for each query. Return the unique union of all retrieved docs.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
retriever: BaseRetriever
|
retriever: BaseRetriever
|
||||||
@ -85,7 +89,7 @@ class MultiQueryRetriever(BaseRetriever):
|
|||||||
*,
|
*,
|
||||||
run_manager: CallbackManagerForRetrieverRun,
|
run_manager: CallbackManagerForRetrieverRun,
|
||||||
) -> List[Document]:
|
) -> List[Document]:
|
||||||
"""Get relevated documents given a user query.
|
"""Get relevant documents given a user query.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
question: user query
|
question: user query
|
||||||
@ -95,8 +99,7 @@ class MultiQueryRetriever(BaseRetriever):
|
|||||||
"""
|
"""
|
||||||
queries = self.generate_queries(query, run_manager)
|
queries = self.generate_queries(query, run_manager)
|
||||||
documents = self.retrieve_documents(queries, run_manager)
|
documents = self.retrieve_documents(queries, run_manager)
|
||||||
unique_documents = self.unique_union(documents)
|
return self.unique_union(documents)
|
||||||
return unique_documents
|
|
||||||
|
|
||||||
def generate_queries(
|
def generate_queries(
|
||||||
self, question: str, run_manager: CallbackManagerForRetrieverRun
|
self, question: str, run_manager: CallbackManagerForRetrieverRun
|
||||||
@ -145,12 +148,4 @@ class MultiQueryRetriever(BaseRetriever):
|
|||||||
Returns:
|
Returns:
|
||||||
List of unique retrieved Documents
|
List of unique retrieved Documents
|
||||||
"""
|
"""
|
||||||
# Create a dictionary with page_content as keys to remove duplicates
|
return _unique_documents(documents)
|
||||||
# TODO: Add Document ID property (e.g., UUID)
|
|
||||||
unique_documents_dict = {
|
|
||||||
(doc.page_content, tuple(sorted(doc.metadata.items()))): doc
|
|
||||||
for doc in documents
|
|
||||||
}
|
|
||||||
|
|
||||||
unique_documents = list(unique_documents_dict.values())
|
|
||||||
return unique_documents
|
|
||||||
|
@ -0,0 +1,40 @@
|
|||||||
|
from typing import List
|
||||||
|
|
||||||
|
import pytest as pytest
|
||||||
|
|
||||||
|
from langchain.retrievers.multi_query import _unique_documents
|
||||||
|
from langchain.schema import Document
|
||||||
|
|
||||||
|
|
||||||
|
# Each case pairs an input document sequence with the expected deduplicated
# output. Cases deliberately cover unhashable metadata (nested dict, list,
# set) — the regression that motivated equality-based deduplication — and
# order-sensitive metadata ([1, 2] vs [2, 1]) that must NOT be merged.
_DEDUP_CASES = [
    ([], []),
    ([Document(page_content="foo")], [Document(page_content="foo")]),
    ([Document(page_content="foo")] * 2, [Document(page_content="foo")]),
    (
        [Document(page_content="foo", metadata={"bar": "baz"})] * 2,
        [Document(page_content="foo", metadata={"bar": "baz"})],
    ),
    (
        [Document(page_content="foo", metadata={"bar": [1, 2]})] * 2,
        [Document(page_content="foo", metadata={"bar": [1, 2]})],
    ),
    (
        [Document(page_content="foo", metadata={"bar": {1, 2}})] * 2,
        [Document(page_content="foo", metadata={"bar": {1, 2}})],
    ),
    (
        [
            Document(page_content="foo", metadata={"bar": [1, 2]}),
            Document(page_content="foo", metadata={"bar": [2, 1]}),
        ],
        [
            Document(page_content="foo", metadata={"bar": [1, 2]}),
            Document(page_content="foo", metadata={"bar": [2, 1]}),
        ],
    ),
]


@pytest.mark.parametrize("documents,expected", _DEDUP_CASES)
def test__unique_documents(documents: List[Document], expected: List[Document]) -> None:
    """Deduplication keeps first occurrences and preserves input order."""
    assert _unique_documents(documents) == expected
|
Loading…
Reference in New Issue
Block a user