mirror of
https://github.com/hwchase17/langchain.git
synced 2025-06-26 16:43:35 +00:00
Qdrant metadata payload keys (#13001)
- **Description:** Qdrant now accepts a list of keys as the `content_payload_key`, so that multiple payload fields can be retrieved at once (the generated document will contain the dictionary `{field: value}` rendered as a string).
- **Issue:** Previously, only a single field could be retrieved from the vector database when performing a search.
- **Dependencies:**
- **Tag maintainer:**
- **Twitter handle:** @jb_dlb

---------

Co-authored-by: Jean Baptiste De La Broise <jeanbaptiste.delabroise@mdpi.com>
This commit is contained in:
parent
ad6dfb6220
commit
38813d7090
@ -82,8 +82,8 @@ class Qdrant(VectorStore):
|
|||||||
qdrant = Qdrant(client, collection_name, embedding_function)
|
qdrant = Qdrant(client, collection_name, embedding_function)
|
||||||
"""
|
"""
|
||||||
|
|
||||||
CONTENT_KEY = "page_content"
|
CONTENT_KEY = ["page_content"]
|
||||||
METADATA_KEY = "metadata"
|
METADATA_KEY = ["metadata"]
|
||||||
VECTOR_NAME = None
|
VECTOR_NAME = None
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
@ -91,8 +91,8 @@ class Qdrant(VectorStore):
|
|||||||
client: Any,
|
client: Any,
|
||||||
collection_name: str,
|
collection_name: str,
|
||||||
embeddings: Optional[Embeddings] = None,
|
embeddings: Optional[Embeddings] = None,
|
||||||
content_payload_key: str = CONTENT_KEY,
|
content_payload_key: Union[list, str] = CONTENT_KEY,
|
||||||
metadata_payload_key: str = METADATA_KEY,
|
metadata_payload_key: Union[list, str] = METADATA_KEY,
|
||||||
distance_strategy: str = "COSINE",
|
distance_strategy: str = "COSINE",
|
||||||
vector_name: Optional[str] = VECTOR_NAME,
|
vector_name: Optional[str] = VECTOR_NAME,
|
||||||
embedding_function: Optional[Callable] = None, # deprecated
|
embedding_function: Optional[Callable] = None, # deprecated
|
||||||
@ -112,6 +112,12 @@ class Qdrant(VectorStore):
|
|||||||
f"got {type(client)}"
|
f"got {type(client)}"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if isinstance(content_payload_key, str): # Ensuring Backward compatibility
|
||||||
|
content_payload_key = [content_payload_key]
|
||||||
|
|
||||||
|
if isinstance(metadata_payload_key, str): # Ensuring Backward compatibility
|
||||||
|
metadata_payload_key = [metadata_payload_key]
|
||||||
|
|
||||||
if embeddings is None and embedding_function is None:
|
if embeddings is None and embedding_function is None:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
"`embeddings` value can't be None. Pass `Embeddings` instance."
|
"`embeddings` value can't be None. Pass `Embeddings` instance."
|
||||||
@ -127,8 +133,14 @@ class Qdrant(VectorStore):
|
|||||||
self._embeddings_function = embedding_function
|
self._embeddings_function = embedding_function
|
||||||
self.client: qdrant_client.QdrantClient = client
|
self.client: qdrant_client.QdrantClient = client
|
||||||
self.collection_name = collection_name
|
self.collection_name = collection_name
|
||||||
self.content_payload_key = content_payload_key or self.CONTENT_KEY
|
self.content_payload_key = (
|
||||||
self.metadata_payload_key = metadata_payload_key or self.METADATA_KEY
|
content_payload_key if content_payload_key is not None else self.CONTENT_KEY
|
||||||
|
)
|
||||||
|
self.metadata_payload_key = (
|
||||||
|
metadata_payload_key
|
||||||
|
if metadata_payload_key is not None
|
||||||
|
else self.METADATA_KEY
|
||||||
|
)
|
||||||
self.vector_name = vector_name or self.VECTOR_NAME
|
self.vector_name = vector_name or self.VECTOR_NAME
|
||||||
|
|
||||||
if embedding_function is not None:
|
if embedding_function is not None:
|
||||||
@ -1178,8 +1190,8 @@ class Qdrant(VectorStore):
|
|||||||
path: Optional[str] = None,
|
path: Optional[str] = None,
|
||||||
collection_name: Optional[str] = None,
|
collection_name: Optional[str] = None,
|
||||||
distance_func: str = "Cosine",
|
distance_func: str = "Cosine",
|
||||||
content_payload_key: str = CONTENT_KEY,
|
content_payload_key: List[str] = CONTENT_KEY,
|
||||||
metadata_payload_key: str = METADATA_KEY,
|
metadata_payload_key: List[str] = METADATA_KEY,
|
||||||
vector_name: Optional[str] = VECTOR_NAME,
|
vector_name: Optional[str] = VECTOR_NAME,
|
||||||
batch_size: int = 64,
|
batch_size: int = 64,
|
||||||
shard_number: Optional[int] = None,
|
shard_number: Optional[int] = None,
|
||||||
@ -1354,8 +1366,8 @@ class Qdrant(VectorStore):
|
|||||||
path: Optional[str] = None,
|
path: Optional[str] = None,
|
||||||
collection_name: Optional[str] = None,
|
collection_name: Optional[str] = None,
|
||||||
distance_func: str = "Cosine",
|
distance_func: str = "Cosine",
|
||||||
content_payload_key: str = CONTENT_KEY,
|
content_payload_key: List[str] = CONTENT_KEY,
|
||||||
metadata_payload_key: str = METADATA_KEY,
|
metadata_payload_key: List[str] = METADATA_KEY,
|
||||||
vector_name: Optional[str] = VECTOR_NAME,
|
vector_name: Optional[str] = VECTOR_NAME,
|
||||||
batch_size: int = 64,
|
batch_size: int = 64,
|
||||||
shard_number: Optional[int] = None,
|
shard_number: Optional[int] = None,
|
||||||
@ -1527,8 +1539,8 @@ class Qdrant(VectorStore):
|
|||||||
path: Optional[str] = None,
|
path: Optional[str] = None,
|
||||||
collection_name: Optional[str] = None,
|
collection_name: Optional[str] = None,
|
||||||
distance_func: str = "Cosine",
|
distance_func: str = "Cosine",
|
||||||
content_payload_key: str = CONTENT_KEY,
|
content_payload_key: List[str] = CONTENT_KEY,
|
||||||
metadata_payload_key: str = METADATA_KEY,
|
metadata_payload_key: List[str] = METADATA_KEY,
|
||||||
vector_name: Optional[str] = VECTOR_NAME,
|
vector_name: Optional[str] = VECTOR_NAME,
|
||||||
shard_number: Optional[int] = None,
|
shard_number: Optional[int] = None,
|
||||||
replication_factor: Optional[int] = None,
|
replication_factor: Optional[int] = None,
|
||||||
@ -1691,8 +1703,8 @@ class Qdrant(VectorStore):
|
|||||||
path: Optional[str] = None,
|
path: Optional[str] = None,
|
||||||
collection_name: Optional[str] = None,
|
collection_name: Optional[str] = None,
|
||||||
distance_func: str = "Cosine",
|
distance_func: str = "Cosine",
|
||||||
content_payload_key: str = CONTENT_KEY,
|
content_payload_key: List[str] = CONTENT_KEY,
|
||||||
metadata_payload_key: str = METADATA_KEY,
|
metadata_payload_key: List[str] = METADATA_KEY,
|
||||||
vector_name: Optional[str] = VECTOR_NAME,
|
vector_name: Optional[str] = VECTOR_NAME,
|
||||||
shard_number: Optional[int] = None,
|
shard_number: Optional[int] = None,
|
||||||
replication_factor: Optional[int] = None,
|
replication_factor: Optional[int] = None,
|
||||||
@ -1888,11 +1900,11 @@ class Qdrant(VectorStore):
|
|||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def _build_payloads(
|
def _build_payloads(
|
||||||
cls,
|
cls: Type[Qdrant],
|
||||||
texts: Iterable[str],
|
texts: Iterable[str],
|
||||||
metadatas: Optional[List[dict]],
|
metadatas: Optional[List[dict]],
|
||||||
content_payload_key: str,
|
content_payload_key: list[str],
|
||||||
metadata_payload_key: str,
|
metadata_payload_key: list[str],
|
||||||
) -> List[dict]:
|
) -> List[dict]:
|
||||||
payloads = []
|
payloads = []
|
||||||
for i, text in enumerate(texts):
|
for i, text in enumerate(texts):
|
||||||
@ -1913,29 +1925,67 @@ class Qdrant(VectorStore):
|
|||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def _document_from_scored_point(
|
def _document_from_scored_point(
|
||||||
cls,
|
cls: Type[Qdrant],
|
||||||
scored_point: Any,
|
scored_point: Any,
|
||||||
content_payload_key: str,
|
content_payload_key: list[str],
|
||||||
metadata_payload_key: str,
|
metadata_payload_key: list[str],
|
||||||
) -> Document:
|
) -> Document:
|
||||||
return Document(
|
payload = scored_point.payload
|
||||||
page_content=scored_point.payload.get(content_payload_key),
|
return Qdrant._document_from_payload(
|
||||||
metadata=scored_point.payload.get(metadata_payload_key) or {},
|
payload=payload,
|
||||||
|
content_payload_key=content_payload_key,
|
||||||
|
metadata_payload_key=metadata_payload_key,
|
||||||
)
|
)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def _document_from_scored_point_grpc(
|
def _document_from_scored_point_grpc(
|
||||||
cls,
|
cls: Type[Qdrant],
|
||||||
scored_point: Any,
|
scored_point: Any,
|
||||||
content_payload_key: str,
|
content_payload_key: list[str],
|
||||||
metadata_payload_key: str,
|
metadata_payload_key: list[str],
|
||||||
) -> Document:
|
) -> Document:
|
||||||
from qdrant_client.conversions.conversion import grpc_to_payload
|
from qdrant_client.conversions.conversion import grpc_to_payload
|
||||||
|
|
||||||
payload = grpc_to_payload(scored_point.payload)
|
payload = grpc_to_payload(scored_point.payload)
|
||||||
|
return Qdrant._document_from_payload(
|
||||||
|
payload=payload,
|
||||||
|
content_payload_key=content_payload_key,
|
||||||
|
metadata_payload_key=metadata_payload_key,
|
||||||
|
)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def _document_from_payload(
|
||||||
|
cls: Type[Qdrant],
|
||||||
|
payload: Any,
|
||||||
|
content_payload_key: list[str],
|
||||||
|
metadata_payload_key: list[str],
|
||||||
|
) -> Document:
|
||||||
|
if len(content_payload_key) == 1:
|
||||||
|
content = payload.get(
|
||||||
|
content_payload_key
|
||||||
|
) # Ensuring backward compatibility
|
||||||
|
elif len(content_payload_key) > 1:
|
||||||
|
content = {
|
||||||
|
content_key: payload.get(content_key)
|
||||||
|
for content_key in content_payload_key
|
||||||
|
}
|
||||||
|
content = str(content) # Ensuring str type output
|
||||||
|
else:
|
||||||
|
content = ""
|
||||||
|
if len(metadata_payload_key) == 1:
|
||||||
|
metadata = payload.get(
|
||||||
|
metadata_payload_key
|
||||||
|
) # Ensuring backward compatibility
|
||||||
|
elif len(metadata_payload_key) > 1:
|
||||||
|
metadata = {
|
||||||
|
metadata_key: payload.get(metadata_key)
|
||||||
|
for metadata_key in metadata_payload_key
|
||||||
|
}
|
||||||
|
else:
|
||||||
|
metadata = {}
|
||||||
return Document(
|
return Document(
|
||||||
page_content=payload[content_payload_key],
|
page_content=content,
|
||||||
metadata=payload.get(metadata_payload_key) or {},
|
metadata=metadata,
|
||||||
)
|
)
|
||||||
|
|
||||||
def _build_condition(self, key: str, value: Any) -> List[rest.FieldCondition]:
|
def _build_condition(self, key: str, value: Any) -> List[rest.FieldCondition]:
|
||||||
|
Loading…
Reference in New Issue
Block a user