rm encoding

This commit is contained in:
Bagatur
2023-09-01 12:05:07 -07:00
parent 47fd3f75a2
commit af3330d162
2 changed files with 50 additions and 20 deletions


@@ -1,7 +1,5 @@
-import json
 import logging
-from functools import partial
-from typing import Callable, List
+from typing import List, Sequence
 
 from langchain.callbacks.manager import CallbackManagerForRetrieverRun
 from langchain.chains.llm import LLMChain
@@ -45,10 +43,14 @@ DEFAULT_QUERY_PROMPT = PromptTemplate(
 )
 
 
+def _unique_documents(documents: Sequence[Document]) -> List[Document]:
+    return [doc for i, doc in enumerate(documents) if doc not in documents[:i]]
+
+
 class MultiQueryRetriever(BaseRetriever):
     """Given a query, use an LLM to write a set of queries.
 
-    Retrieve docs for each query. Rake the unique union of all retrieved docs.
+    Retrieve docs for each query. Return the unique union of all retrieved docs.
     """
 
     retriever: BaseRetriever
@@ -87,7 +89,7 @@ class MultiQueryRetriever(BaseRetriever):
         *,
         run_manager: CallbackManagerForRetrieverRun,
     ) -> List[Document]:
-        """Get relevated documents given a user query.
+        """Get relevant documents given a user query.
 
         Args:
             question: user query
@@ -97,8 +99,7 @@ class MultiQueryRetriever(BaseRetriever):
"""
queries = self.generate_queries(query, run_manager)
documents = self.retrieve_documents(queries, run_manager)
unique_documents = self.unique_union(documents)
return unique_documents
return self.unique_union(documents)
def generate_queries(
self, question: str, run_manager: CallbackManagerForRetrieverRun
@@ -138,11 +139,7 @@ class MultiQueryRetriever(BaseRetriever):
             documents.extend(docs)
         return documents
 
-    def unique_union(
-        self,
-        documents: List[Document],
-        metadata_encoder: Callable = partial(json.dumps, sort_keys=True),
-    ) -> List[Document]:
+    def unique_union(self, documents: List[Document]) -> List[Document]:
         """Get unique Documents.
 
         Args:
@@ -151,11 +148,4 @@ class MultiQueryRetriever(BaseRetriever):
         Returns:
             List of unique retrieved Documents
         """
-        # Create a dictionary with page_content as keys to remove duplicates
-        # TODO: Add Document ID property (e.g., UUID)
-        unique_documents_dict = {
-            (doc.page_content, metadata_encoder(doc.metadata)): doc
-            for doc in documents
-        }
-        unique_documents = list(unique_documents_dict.values())
-        return unique_documents
+        return _unique_documents(documents)
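
Note on the change above: the removed `metadata_encoder` keyed documents by `json.dumps(doc.metadata, sort_keys=True)`, which raises on metadata values that are not JSON-serializable (sets, for example). The new module-level `_unique_documents` helper dedupes by `Document` equality instead, at the cost of a quadratic membership scan. A minimal sketch of the difference, assuming langchain from around this commit is installed and that `Document` equality is field-based; the `docs` and `encoder` names are illustrative only:

import json
from functools import partial

from langchain.schema import Document

docs = [
    Document(page_content="foo", metadata={"bar": {1, 2}}),  # set-valued metadata
    Document(page_content="foo", metadata={"bar": {1, 2}}),
]

# Old scheme: key documents by JSON-encoded metadata; a set is not JSON-serializable,
# so building the dedup dict raises TypeError.
encoder = partial(json.dumps, sort_keys=True)
try:
    {(doc.page_content, encoder(doc.metadata)): doc for doc in docs}
except TypeError as err:
    print("json-based key fails:", err)

# New scheme: order-preserving dedup via Document equality, no encoding step needed.
def _unique_documents(documents):
    return [doc for i, doc in enumerate(documents) if doc not in documents[:i]]

print(len(_unique_documents(docs)))  # -> 1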


@@ -0,0 +1,40 @@
+from typing import List
+
+import pytest as pytest
+
+from langchain.retrievers.multi_query import _unique_documents
+from langchain.schema import Document
+
+
+@pytest.mark.parametrize(
+    "documents,expected",
+    [
+        ([], []),
+        ([Document(page_content="foo")], [Document(page_content="foo")]),
+        ([Document(page_content="foo")] * 2, [Document(page_content="foo")]),
+        (
+            [Document(page_content="foo", metadata={"bar": "baz"})] * 2,
+            [Document(page_content="foo", metadata={"bar": "baz"})],
+        ),
+        (
+            [Document(page_content="foo", metadata={"bar": [1, 2]})] * 2,
+            [Document(page_content="foo", metadata={"bar": [1, 2]})],
+        ),
+        (
+            [Document(page_content="foo", metadata={"bar": {1, 2}})] * 2,
+            [Document(page_content="foo", metadata={"bar": {1, 2}})],
+        ),
+        (
+            [
+                Document(page_content="foo", metadata={"bar": [1, 2]}),
+                Document(page_content="foo", metadata={"bar": [2, 1]}),
+            ],
+            [
+                Document(page_content="foo", metadata={"bar": [1, 2]}),
+                Document(page_content="foo", metadata={"bar": [2, 1]}),
+            ],
+        ),
+    ],
+)
+def test__unique_documents(documents: List[Document], expected: List[Document]) -> None:
+    assert _unique_documents(documents) == expected
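
A quick, hedged companion to the final parametrize case above (using the same import path the test does): list-valued metadata compares order-sensitively, so documents whose lists differ only in order are both kept, while exact repeats collapse.

from langchain.retrievers.multi_query import _unique_documents
from langchain.schema import Document

a = Document(page_content="foo", metadata={"bar": [1, 2]})
b = Document(page_content="foo", metadata={"bar": [2, 1]})

assert a != b                                 # [1, 2] != [2, 1], order matters
assert _unique_documents([a, b]) == [a, b]    # both survive deduplication
assert _unique_documents([a, a]) == [a]       # exact repeats collapse to one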