diff --git a/libs/langchain/langchain/retrievers/multi_query.py b/libs/langchain/langchain/retrievers/multi_query.py index 8398ce40d9d..b99bb84a7ea 100644 --- a/libs/langchain/langchain/retrievers/multi_query.py +++ b/libs/langchain/langchain/retrievers/multi_query.py @@ -1,5 +1,5 @@ import logging -from typing import List +from typing import List, Sequence from langchain.callbacks.manager import CallbackManagerForRetrieverRun from langchain.chains.llm import LLMChain @@ -43,10 +43,14 @@ DEFAULT_QUERY_PROMPT = PromptTemplate( ) +def _unique_documents(documents: Sequence[Document]) -> List[Document]: + return [doc for i, doc in enumerate(documents) if doc not in documents[:i]] + + class MultiQueryRetriever(BaseRetriever): """Given a query, use an LLM to write a set of queries. - Retrieve docs for each query. Rake the unique union of all retrieved docs. + Retrieve docs for each query. Return the unique union of all retrieved docs. """ retriever: BaseRetriever @@ -85,7 +89,7 @@ class MultiQueryRetriever(BaseRetriever): *, run_manager: CallbackManagerForRetrieverRun, ) -> List[Document]: - """Get relevated documents given a user query. + """Get relevant documents given a user query. Args: question: user query @@ -95,8 +99,7 @@ class MultiQueryRetriever(BaseRetriever): """ queries = self.generate_queries(query, run_manager) documents = self.retrieve_documents(queries, run_manager) - unique_documents = self.unique_union(documents) - return unique_documents + return self.unique_union(documents) def generate_queries( self, question: str, run_manager: CallbackManagerForRetrieverRun @@ -145,12 +148,4 @@ class MultiQueryRetriever(BaseRetriever): Returns: List of unique retrieved Documents """ - # Create a dictionary with page_content as keys to remove duplicates - # TODO: Add Document ID property (e.g., UUID) - unique_documents_dict = { - (doc.page_content, tuple(sorted(doc.metadata.items()))): doc - for doc in documents - } - - unique_documents = list(unique_documents_dict.values()) - return unique_documents + return _unique_documents(documents) diff --git a/libs/langchain/tests/unit_tests/retrievers/test_multi_query.py b/libs/langchain/tests/unit_tests/retrievers/test_multi_query.py new file mode 100644 index 00000000000..978950ec58a --- /dev/null +++ b/libs/langchain/tests/unit_tests/retrievers/test_multi_query.py @@ -0,0 +1,40 @@ +from typing import List + +import pytest as pytest + +from langchain.retrievers.multi_query import _unique_documents +from langchain.schema import Document + + +@pytest.mark.parametrize( + "documents,expected", + [ + ([], []), + ([Document(page_content="foo")], [Document(page_content="foo")]), + ([Document(page_content="foo")] * 2, [Document(page_content="foo")]), + ( + [Document(page_content="foo", metadata={"bar": "baz"})] * 2, + [Document(page_content="foo", metadata={"bar": "baz"})], + ), + ( + [Document(page_content="foo", metadata={"bar": [1, 2]})] * 2, + [Document(page_content="foo", metadata={"bar": [1, 2]})], + ), + ( + [Document(page_content="foo", metadata={"bar": {1, 2}})] * 2, + [Document(page_content="foo", metadata={"bar": {1, 2}})], + ), + ( + [ + Document(page_content="foo", metadata={"bar": [1, 2]}), + Document(page_content="foo", metadata={"bar": [2, 1]}), + ], + [ + Document(page_content="foo", metadata={"bar": [1, 2]}), + Document(page_content="foo", metadata={"bar": [2, 1]}), + ], + ), + ], +) +def test__unique_documents(documents: List[Document], expected: List[Document]) -> None: + assert _unique_documents(documents) == expected