mirror of
https://github.com/hwchase17/langchain.git
synced 2025-07-12 15:59:56 +00:00
Add support for Qdrant nested filter (#4354)
# Add support for Qdrant nested filter

This extends the filter functionality of the Qdrant vectorstore. The current filter implementation is limited to a single-level metadata structure; however, Qdrant supports filtering on nested metadata. With this change, users can take full advantage of Qdrant's filtering capabilities when using it as the vectorstore.

Reference: https://qdrant.tech/documentation/filtering/#nested-key

---------

Signed-off-by: Aivin V. Solatorio <avsolatorio@gmail.com>
This commit is contained in:
parent
872605a5c5
commit
6335cb5b3a
@ -5,7 +5,18 @@ import uuid
|
|||||||
import warnings
|
import warnings
|
||||||
from hashlib import md5
|
from hashlib import md5
|
||||||
from operator import itemgetter
|
from operator import itemgetter
|
||||||
from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, Type, Union
|
from typing import (
|
||||||
|
TYPE_CHECKING,
|
||||||
|
Any,
|
||||||
|
Callable,
|
||||||
|
Dict,
|
||||||
|
Iterable,
|
||||||
|
List,
|
||||||
|
Optional,
|
||||||
|
Tuple,
|
||||||
|
Type,
|
||||||
|
Union,
|
||||||
|
)
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
@ -14,7 +25,11 @@ from langchain.embeddings.base import Embeddings
|
|||||||
from langchain.vectorstores import VectorStore
|
from langchain.vectorstores import VectorStore
|
||||||
from langchain.vectorstores.utils import maximal_marginal_relevance
|
from langchain.vectorstores.utils import maximal_marginal_relevance
|
||||||
|
|
||||||
MetadataFilter = Dict[str, Union[str, int, bool]]
|
if TYPE_CHECKING:
|
||||||
|
from qdrant_client.http import models as rest
|
||||||
|
|
||||||
|
|
||||||
|
MetadataFilter = Dict[str, Union[str, int, bool, dict, list]]
|
||||||
|
|
||||||
|
|
||||||
class Qdrant(VectorStore):
|
class Qdrant(VectorStore):
|
||||||
@ -461,18 +476,42 @@ class Qdrant(VectorStore):
|
|||||||
metadata=scored_point.payload.get(metadata_payload_key) or {},
|
metadata=scored_point.payload.get(metadata_payload_key) or {},
|
||||||
)
|
)
|
||||||
|
|
||||||
def _qdrant_filter_from_dict(self, filter: Optional[MetadataFilter]) -> Any:
|
def _build_condition(self, key: str, value: Any) -> List[rest.FieldCondition]:
|
||||||
if filter is None or 0 == len(filter):
|
|
||||||
return None
|
|
||||||
|
|
||||||
from qdrant_client.http import models as rest
|
from qdrant_client.http import models as rest
|
||||||
|
|
||||||
return rest.Filter(
|
out = []
|
||||||
must=[
|
|
||||||
|
if isinstance(value, dict):
|
||||||
|
for _key, value in value.items():
|
||||||
|
out.extend(self._build_condition(f"{key}.{_key}", value))
|
||||||
|
elif isinstance(value, list):
|
||||||
|
for _value in value:
|
||||||
|
if isinstance(_value, dict):
|
||||||
|
out.extend(self._build_condition(f"{key}[]", _value))
|
||||||
|
else:
|
||||||
|
out.extend(self._build_condition(f"{key}", _value))
|
||||||
|
else:
|
||||||
|
out.append(
|
||||||
rest.FieldCondition(
|
rest.FieldCondition(
|
||||||
key=f"{self.metadata_payload_key}.{key}",
|
key=f"{self.metadata_payload_key}.{key}",
|
||||||
match=rest.MatchValue(value=value),
|
match=rest.MatchValue(value=value),
|
||||||
)
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
return out
|
||||||
|
|
||||||
|
def _qdrant_filter_from_dict(
|
||||||
|
self, filter: Optional[MetadataFilter]
|
||||||
|
) -> Optional[rest.Filter]:
|
||||||
|
from qdrant_client.http import models as rest
|
||||||
|
|
||||||
|
if not filter:
|
||||||
|
return None
|
||||||
|
|
||||||
|
return rest.Filter(
|
||||||
|
must=[
|
||||||
|
condition
|
||||||
for key, value in filter.items()
|
for key, value in filter.items()
|
||||||
|
for condition in self._build_condition(key, value)
|
||||||
]
|
]
|
||||||
)
|
)
|
||||||
|
@ -78,15 +78,26 @@ def test_qdrant_with_metadatas(
|
|||||||
def test_qdrant_similarity_search_filters() -> None:
|
def test_qdrant_similarity_search_filters() -> None:
|
||||||
"""Test end to end construction and search."""
|
"""Test end to end construction and search."""
|
||||||
texts = ["foo", "bar", "baz"]
|
texts = ["foo", "bar", "baz"]
|
||||||
metadatas = [{"page": i} for i in range(len(texts))]
|
metadatas = [
|
||||||
|
{"page": i, "metadata": {"page": i + 1, "pages": [i + 2, -1]}}
|
||||||
|
for i in range(len(texts))
|
||||||
|
]
|
||||||
docsearch = Qdrant.from_texts(
|
docsearch = Qdrant.from_texts(
|
||||||
texts,
|
texts,
|
||||||
FakeEmbeddings(),
|
FakeEmbeddings(),
|
||||||
metadatas=metadatas,
|
metadatas=metadatas,
|
||||||
location=":memory:",
|
location=":memory:",
|
||||||
)
|
)
|
||||||
output = docsearch.similarity_search("foo", k=1, filter={"page": 1})
|
|
||||||
assert output == [Document(page_content="bar", metadata={"page": 1})]
|
output = docsearch.similarity_search(
|
||||||
|
"foo", k=1, filter={"page": 1, "metadata": {"page": 2, "pages": [3]}}
|
||||||
|
)
|
||||||
|
assert output == [
|
||||||
|
Document(
|
||||||
|
page_content="bar",
|
||||||
|
metadata={"page": 1, "metadata": {"page": 2, "pages": [3, -1]}},
|
||||||
|
)
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize(
|
@pytest.mark.parametrize(
|
||||||
|
Loading…
Reference in New Issue
Block a user