Compare commits

...

3 Commits

Author SHA1 Message Date
Bagatur
af3330d162 rm encoding 2023-09-01 12:05:07 -07:00
lorenzofavaro
47fd3f75a2 Metadata encoder parameterization 2023-08-30 23:21:25 +02:00
lorenzofavaro
bec33a85bc Fix: Nested Dicts Handling of Document Metadata 2023-08-28 22:45:40 +02:00
2 changed files with 49 additions and 14 deletions

View File

@@ -1,5 +1,5 @@
import logging
from typing import List
from typing import List, Sequence
from langchain.callbacks.manager import CallbackManagerForRetrieverRun
from langchain.chains.llm import LLMChain
@@ -43,10 +43,14 @@ DEFAULT_QUERY_PROMPT = PromptTemplate(
)
def _unique_documents(documents: Sequence[Document]) -> List[Document]:
return [doc for i, doc in enumerate(documents) if doc not in documents[:i]]
class MultiQueryRetriever(BaseRetriever):
"""Given a query, use an LLM to write a set of queries.
Retrieve docs for each query. Rake the unique union of all retrieved docs.
Retrieve docs for each query. Return the unique union of all retrieved docs.
"""
retriever: BaseRetriever
@@ -85,7 +89,7 @@ class MultiQueryRetriever(BaseRetriever):
*,
run_manager: CallbackManagerForRetrieverRun,
) -> List[Document]:
"""Get relevated documents given a user query.
"""Get relevant documents given a user query.
Args:
question: user query
@@ -95,8 +99,7 @@ class MultiQueryRetriever(BaseRetriever):
"""
queries = self.generate_queries(query, run_manager)
documents = self.retrieve_documents(queries, run_manager)
unique_documents = self.unique_union(documents)
return unique_documents
return self.unique_union(documents)
def generate_queries(
self, question: str, run_manager: CallbackManagerForRetrieverRun
@@ -145,12 +148,4 @@ class MultiQueryRetriever(BaseRetriever):
Returns:
List of unique retrieved Documents
"""
# Create a dictionary with page_content as keys to remove duplicates
# TODO: Add Document ID property (e.g., UUID)
unique_documents_dict = {
(doc.page_content, tuple(sorted(doc.metadata.items()))): doc
for doc in documents
}
unique_documents = list(unique_documents_dict.values())
return unique_documents
return _unique_documents(documents)

View File

@@ -0,0 +1,40 @@
from typing import List
import pytest as pytest
from langchain.retrievers.multi_query import _unique_documents
from langchain.schema import Document
@pytest.mark.parametrize(
    "documents,expected",
    [
        # No documents in, no documents out.
        ([], []),
        # A single document is returned unchanged.
        ([Document(page_content="foo")], [Document(page_content="foo")]),
        # The same document repeated twice collapses to one.
        ([Document(page_content="foo")] * 2, [Document(page_content="foo")]),
        # Repeated document with string-valued metadata.
        (
            [Document(page_content="foo", metadata={"bar": "baz"})] * 2,
            [Document(page_content="foo", metadata={"bar": "baz"})],
        ),
        # Repeated document whose metadata value is a list.
        (
            [Document(page_content="foo", metadata={"bar": [1, 2]})] * 2,
            [Document(page_content="foo", metadata={"bar": [1, 2]})],
        ),
        # Repeated document whose metadata value is a set.
        (
            [Document(page_content="foo", metadata={"bar": {1, 2}})] * 2,
            [Document(page_content="foo", metadata={"bar": {1, 2}})],
        ),
        # Metadata lists in a different order are distinct; both are kept.
        (
            [
                Document(page_content="foo", metadata={"bar": [1, 2]}),
                Document(page_content="foo", metadata={"bar": [2, 1]}),
            ],
            [
                Document(page_content="foo", metadata={"bar": [1, 2]}),
                Document(page_content="foo", metadata={"bar": [2, 1]}),
            ],
        ),
    ],
)
def test__unique_documents(
    documents: List[Document], expected: List[Document]
) -> None:
    """Check that ``_unique_documents`` returns ``expected`` for each case."""
    deduplicated = _unique_documents(documents)
    assert deduplicated == expected