diff --git a/libs/community/langchain_community/vectorstores/qdrant.py b/libs/community/langchain_community/vectorstores/qdrant.py index 7c18da82ec0..5073943819b 100644 --- a/libs/community/langchain_community/vectorstores/qdrant.py +++ b/libs/community/langchain_community/vectorstores/qdrant.py @@ -22,11 +22,11 @@ from typing import ( ) import numpy as np -from langchain_core.documents import Document from langchain_core.embeddings import Embeddings from langchain_core.runnables.config import run_in_executor from langchain_core.vectorstores import VectorStore +from langchain_community.docstore.document import Document from langchain_community.vectorstores.utils import maximal_marginal_relevance if TYPE_CHECKING: @@ -94,6 +94,7 @@ class Qdrant(VectorStore): metadata_payload_key: str = METADATA_KEY, distance_strategy: str = "COSINE", vector_name: Optional[str] = VECTOR_NAME, + async_client: Optional[Any] = None, embedding_function: Optional[Callable] = None, # deprecated ): """Initialize with necessary components.""" @@ -111,6 +112,14 @@ class Qdrant(VectorStore): f"got {type(client)}" ) + if async_client is not None and not isinstance( + async_client, qdrant_client.AsyncQdrantClient + ): + raise ValueError( + f"async_client should be an instance of qdrant_client.AsyncQdrantClient" + f"got {type(async_client)}" + ) + if embeddings is None and embedding_function is None: raise ValueError( "`embeddings` value can't be None. Pass `Embeddings` instance." 
@@ -125,6 +134,7 @@ class Qdrant(VectorStore): self._embeddings = embeddings self._embeddings_function = embedding_function self.client: qdrant_client.QdrantClient = client + self.async_client: Optional[qdrant_client.AsyncQdrantClient] = async_client self.collection_name = collection_name self.content_payload_key = content_payload_key or self.CONTENT_KEY self.metadata_payload_key = metadata_payload_key or self.METADATA_KEY @@ -208,18 +218,21 @@ class Qdrant(VectorStore): Returns: List of ids from adding the texts into the vectorstore. """ - from qdrant_client import grpc # noqa - from qdrant_client.conversions.conversion import RestToGrpc + from qdrant_client.local.async_qdrant_local import AsyncQdrantLocal + + if self.async_client is None or isinstance( + self.async_client._client, AsyncQdrantLocal + ): + raise NotImplementedError( + "QdrantLocal cannot interoperate with sync and async clients" + ) added_ids = [] async for batch_ids, points in self._agenerate_rest_batches( texts, metadatas, ids, batch_size ): - await self.client.async_grpc_points.Upsert( - grpc.UpsertPoints( - collection_name=self.collection_name, - points=[RestToGrpc.convert_point_struct(point) for point in points], - ) + await self.async_client.upsert( + collection_name=self.collection_name, points=points, **kwargs ) added_ids.extend(batch_ids) @@ -399,7 +412,7 @@ class Qdrant(VectorStore): - 'all' - query all replicas, and return values present in all replicas **kwargs: Any other named arguments to pass through to - QdrantClient.async_grpc_points.Search(). + AsyncQdrantClient.Search(). Returns: List of documents most similar to the query text and distance for each. @@ -514,7 +527,7 @@ class Qdrant(VectorStore): - 'all' - query all replicas, and return values present in all replicas **kwargs: Any other named arguments to pass through to - QdrantClient.async_grpc_points.Search(). + AsyncQdrantClient.Search(). Returns: List of Documents most similar to the query. 
@@ -614,56 +627,6 @@ class Qdrant(VectorStore): for result in results ] - async def _asearch_with_score_by_vector( - self, - embedding: List[float], - *, - k: int = 4, - filter: Optional[MetadataFilter] = None, - search_params: Optional[common_types.SearchParams] = None, - offset: int = 0, - score_threshold: Optional[float] = None, - consistency: Optional[common_types.ReadConsistency] = None, - with_vectors: bool = False, - **kwargs: Any, - ) -> Any: - """Return results most similar to embedding vector.""" - from qdrant_client import grpc # noqa - from qdrant_client.conversions.conversion import RestToGrpc - from qdrant_client.http import models as rest - - if filter is not None and isinstance(filter, dict): - warnings.warn( - "Using dict as a `filter` is deprecated. Please use qdrant-client " - "filters directly: " - "https://qdrant.tech/documentation/concepts/filtering/", - DeprecationWarning, - ) - qdrant_filter = self._qdrant_filter_from_dict(filter) - else: - qdrant_filter = filter - - if qdrant_filter is not None and isinstance(qdrant_filter, rest.Filter): - qdrant_filter = RestToGrpc.convert_filter(qdrant_filter) - - response = await self.client.async_grpc_points.Search( - grpc.SearchPoints( - collection_name=self.collection_name, - vector_name=self.vector_name, - vector=embedding, - filter=qdrant_filter, - params=search_params, - limit=k, - offset=offset, - with_payload=grpc.WithPayloadSelector(enable=True), - with_vectors=grpc.WithVectorsSelector(enable=with_vectors), - score_threshold=score_threshold, - read_consistency=consistency, - **kwargs, - ) - ) - return response - @sync_call_fallback async def asimilarity_search_with_score_by_vector( self, @@ -706,30 +669,55 @@ class Qdrant(VectorStore): - 'all' - query all replicas, and return values present in all replicas **kwargs: Any other named arguments to pass through to - QdrantClient.async_grpc_points.Search(). + AsyncQdrantClient.Search(). 
Returns: List of documents most similar to the query text and distance for each. """ - response = await self._asearch_with_score_by_vector( - embedding, - k=k, - filter=filter, + from qdrant_client.local.async_qdrant_local import AsyncQdrantLocal + + if self.async_client is None or isinstance( + self.async_client._client, AsyncQdrantLocal + ): + raise NotImplementedError( + "QdrantLocal cannot interoperate with sync and async clients" + ) + if filter is not None and isinstance(filter, dict): + warnings.warn( + "Using dict as a `filter` is deprecated. Please use qdrant-client " + "filters directly: " + "https://qdrant.tech/documentation/concepts/filtering/", + DeprecationWarning, + ) + qdrant_filter = self._qdrant_filter_from_dict(filter) + else: + qdrant_filter = filter + + query_vector = embedding + if self.vector_name is not None: + query_vector = (self.vector_name, embedding) # type: ignore[assignment] + + results = await self.async_client.search( + collection_name=self.collection_name, + query_vector=query_vector, + query_filter=qdrant_filter, search_params=search_params, + limit=k, offset=offset, + with_payload=True, + with_vectors=False, # Langchain does not expect vectors to be returned score_threshold=score_threshold, consistency=consistency, **kwargs, ) - return [ ( - self._document_from_scored_point_grpc( + self._document_from_scored_point( result, self.content_payload_key, self.metadata_payload_key ), result.score, ) - for result in response.result + for result in results ] def max_marginal_relevance_search( @@ -843,7 +831,7 @@ class Qdrant(VectorStore): - 'all' - query all replicas, and return values present in all replicas **kwargs: Any other named arguments to pass through to - QdrantClient.async_grpc_points.Search(). + AsyncQdrantClient.Search(). Returns: List of Documents selected by maximal marginal relevance. 
""" @@ -968,7 +956,7 @@ class Qdrant(VectorStore): - 'all' - query all replicas, and return values present in all replicas **kwargs: Any other named arguments to pass through to - QdrantClient.async_grpc_points.Search(). + AsyncQdrantClient.Search(). Returns: List of Documents selected by maximal marginal relevance and distance for each. @@ -1099,41 +1087,45 @@ class Qdrant(VectorStore): List of Documents selected by maximal marginal relevance and distance for each. """ - from qdrant_client.conversions.conversion import GrpcToRest + from qdrant_client.local.async_qdrant_local import AsyncQdrantLocal - response = await self._asearch_with_score_by_vector( - embedding, - k=fetch_k, - filter=filter, + if self.async_client is None or isinstance( + self.async_client._client, AsyncQdrantLocal + ): + raise NotImplementedError( + "QdrantLocal cannot interoperate with sync and async clients" + ) + query_vector = embedding + if self.vector_name is not None: + query_vector = (self.vector_name, query_vector) # type: ignore[assignment] + + results = await self.async_client.search( + collection_name=self.collection_name, + query_vector=query_vector, + query_filter=filter, search_params=search_params, + limit=fetch_k, + with_payload=True, + with_vectors=True, score_threshold=score_threshold, consistency=consistency, - with_vectors=True, **kwargs, ) - results = [ - GrpcToRest.convert_vectors(result.vectors) for result in response.result - ] - embeddings: List[List[float]] = [ - result.get(self.vector_name) # type: ignore - if isinstance(result, dict) - else result + embeddings = [ + result.vector.get(self.vector_name) # type: ignore[index, union-attr] + if self.vector_name is not None + else result.vector for result in results ] - mmr_selected: List[int] = maximal_marginal_relevance( - np.array(embedding), - embeddings, - k=k, - lambda_mult=lambda_mult, + mmr_selected = maximal_marginal_relevance( + np.array(embedding), embeddings, k=k, lambda_mult=lambda_mult ) return [ ( - 
self._document_from_scored_point_grpc( - response.result[i], - self.content_payload_key, - self.metadata_payload_key, + self._document_from_scored_point( + results[i], self.content_payload_key, self.metadata_payload_key ), - response.result[i].score, + results[i].score, ) for i in mmr_selected ] @@ -1543,7 +1535,7 @@ class Qdrant(VectorStore): **kwargs: Any, ) -> Qdrant: try: - import qdrant_client + import qdrant_client # noqa except ImportError: raise ValueError( "Could not import qdrant-client python package. " @@ -1558,7 +1550,7 @@ class Qdrant(VectorStore): vector_size = len(partial_embeddings[0]) collection_name = collection_name or uuid.uuid4().hex distance_func = distance_func.upper() - client = qdrant_client.QdrantClient( + client, async_client = cls._generate_clients( location=location, url=url, port=port, @@ -1669,6 +1661,7 @@ class Qdrant(VectorStore): metadata_payload_key=metadata_payload_key, distance_strategy=distance_func, vector_name=vector_name, + async_client=async_client, ) return qdrant @@ -1707,7 +1700,7 @@ class Qdrant(VectorStore): **kwargs: Any, ) -> Qdrant: try: - import qdrant_client + import qdrant_client # noqa except ImportError: raise ValueError( "Could not import qdrant-client python package. 
@staticmethod
def _generate_clients(
    location: Optional[str] = None,
    url: Optional[str] = None,
    port: Optional[int] = 6333,
    grpc_port: int = 6334,
    prefer_grpc: bool = False,
    https: Optional[bool] = None,
    api_key: Optional[str] = None,
    prefix: Optional[str] = None,
    timeout: Optional[float] = None,
    host: Optional[str] = None,
    path: Optional[str] = None,
    **kwargs: Any,
) -> Tuple[Any, Any]:
    """Create the sync Qdrant client and, where possible, its async twin.

    Both clients are constructed from exactly the same connection options.

    Args:
        location: Forwarded to the client constructors; the special value
            ``":memory:"`` disables creation of the async client.
        url: Forwarded to the client constructors.
        port: Forwarded to the client constructors.
        grpc_port: Forwarded to the client constructors.
        prefer_grpc: Forwarded to the client constructors.
        https: Forwarded to the client constructors.
        api_key: Forwarded to the client constructors.
        prefix: Forwarded to the client constructors.
        timeout: Forwarded to the client constructors.
        host: Forwarded to the client constructors.
        path: Forwarded to the client constructors; any non-``None`` value
            disables creation of the async client.
        **kwargs: Extra keyword arguments forwarded unchanged to both
            client constructors.

    Returns:
        A ``(sync_client, async_client)`` tuple. ``async_client`` is ``None``
        when ``location == ":memory:"`` or ``path`` is given.
    """
    from qdrant_client import AsyncQdrantClient, QdrantClient

    # Collect the options once so both constructors receive identical input.
    connection_options = dict(
        location=location,
        url=url,
        port=port,
        grpc_port=grpc_port,
        prefer_grpc=prefer_grpc,
        https=https,
        api_key=api_key,
        prefix=prefix,
        timeout=timeout,
        host=host,
        path=path,
        **kwargs,
    )
    sync_client = QdrantClient(**connection_options)

    # Local Qdrant cannot co-exist with Sync and Async clients, so in that
    # case no async client is created and callers fall back to sync operations.
    runs_locally = location == ":memory:" or path is not None
    if runs_locally:
        return sync_client, None

    return sync_client, AsyncQdrantClient(**connection_options)