mirror of
https://github.com/hwchase17/langchain.git
synced 2026-02-21 14:43:07 +00:00
rm encoding
This commit is contained in:
@@ -1,7 +1,5 @@
|
||||
import json
|
||||
import logging
|
||||
from functools import partial
|
||||
from typing import Callable, List
|
||||
from typing import List, Sequence
|
||||
|
||||
from langchain.callbacks.manager import CallbackManagerForRetrieverRun
|
||||
from langchain.chains.llm import LLMChain
|
||||
@@ -45,10 +43,14 @@ DEFAULT_QUERY_PROMPT = PromptTemplate(
|
||||
)
|
||||
|
||||
|
||||
def _unique_documents(documents: Sequence[Document]) -> List[Document]:
|
||||
return [doc for i, doc in enumerate(documents) if doc not in documents[:i]]
|
||||
|
||||
|
||||
class MultiQueryRetriever(BaseRetriever):
|
||||
"""Given a query, use an LLM to write a set of queries.
|
||||
|
||||
Retrieve docs for each query. Rake the unique union of all retrieved docs.
|
||||
Retrieve docs for each query. Return the unique union of all retrieved docs.
|
||||
"""
|
||||
|
||||
retriever: BaseRetriever
|
||||
@@ -87,7 +89,7 @@ class MultiQueryRetriever(BaseRetriever):
|
||||
*,
|
||||
run_manager: CallbackManagerForRetrieverRun,
|
||||
) -> List[Document]:
|
||||
"""Get relevated documents given a user query.
|
||||
"""Get relevant documents given a user query.
|
||||
|
||||
Args:
|
||||
question: user query
|
||||
@@ -97,8 +99,7 @@ class MultiQueryRetriever(BaseRetriever):
|
||||
"""
|
||||
queries = self.generate_queries(query, run_manager)
|
||||
documents = self.retrieve_documents(queries, run_manager)
|
||||
unique_documents = self.unique_union(documents)
|
||||
return unique_documents
|
||||
return self.unique_union(documents)
|
||||
|
||||
def generate_queries(
|
||||
self, question: str, run_manager: CallbackManagerForRetrieverRun
|
||||
@@ -138,11 +139,7 @@ class MultiQueryRetriever(BaseRetriever):
|
||||
documents.extend(docs)
|
||||
return documents
|
||||
|
||||
def unique_union(
|
||||
self,
|
||||
documents: List[Document],
|
||||
metadata_encoder: Callable = partial(json.dumps, sort_keys=True),
|
||||
) -> List[Document]:
|
||||
def unique_union(self, documents: List[Document]) -> List[Document]:
|
||||
"""Get unique Documents.
|
||||
|
||||
Args:
|
||||
@@ -151,11 +148,4 @@ class MultiQueryRetriever(BaseRetriever):
|
||||
Returns:
|
||||
List of unique retrieved Documents
|
||||
"""
|
||||
# Create a dictionary with page_content as keys to remove duplicates
|
||||
# TODO: Add Document ID property (e.g., UUID)
|
||||
unique_documents_dict = {
|
||||
(doc.page_content, metadata_encoder(doc.metadata)): doc for doc in documents
|
||||
}
|
||||
|
||||
unique_documents = list(unique_documents_dict.values())
|
||||
return unique_documents
|
||||
return _unique_documents(documents)
|
||||
|
||||
@@ -0,0 +1,40 @@
|
||||
from typing import List
|
||||
|
||||
import pytest as pytest
|
||||
|
||||
from langchain.retrievers.multi_query import _unique_documents
|
||||
from langchain.schema import Document
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"documents,expected",
|
||||
[
|
||||
([], []),
|
||||
([Document(page_content="foo")], [Document(page_content="foo")]),
|
||||
([Document(page_content="foo")] * 2, [Document(page_content="foo")]),
|
||||
(
|
||||
[Document(page_content="foo", metadata={"bar": "baz"})] * 2,
|
||||
[Document(page_content="foo", metadata={"bar": "baz"})],
|
||||
),
|
||||
(
|
||||
[Document(page_content="foo", metadata={"bar": [1, 2]})] * 2,
|
||||
[Document(page_content="foo", metadata={"bar": [1, 2]})],
|
||||
),
|
||||
(
|
||||
[Document(page_content="foo", metadata={"bar": {1, 2}})] * 2,
|
||||
[Document(page_content="foo", metadata={"bar": {1, 2}})],
|
||||
),
|
||||
(
|
||||
[
|
||||
Document(page_content="foo", metadata={"bar": [1, 2]}),
|
||||
Document(page_content="foo", metadata={"bar": [2, 1]}),
|
||||
],
|
||||
[
|
||||
Document(page_content="foo", metadata={"bar": [1, 2]}),
|
||||
Document(page_content="foo", metadata={"bar": [2, 1]}),
|
||||
],
|
||||
),
|
||||
],
|
||||
)
|
||||
def test__unique_documents(documents: List[Document], expected: List[Document]) -> None:
|
||||
assert _unique_documents(documents) == expected
|
||||
Reference in New Issue
Block a user