refactor: Qdrant async improvements (#14492)

Follow up on https://github.com/langchain-ai/langchain/pull/13048.
This PR simplifies the Qdrant async implementation by replacing the
internal gRPC calls with the corresponding `AsyncQdrantClient` methods.
This is a backward-compatible change; no additional steps are required
after merge.

---------

Co-authored-by: Harrison Chase <hw.chase.17@gmail.com>
This commit is contained in:
Anush
2024-01-03 09:37:48 +05:30
committed by GitHub
parent cda68d717c
commit 58cc7878e9

View File

@@ -22,11 +22,11 @@ from typing import (
) )
import numpy as np import numpy as np
from langchain_core.documents import Document
from langchain_core.embeddings import Embeddings from langchain_core.embeddings import Embeddings
from langchain_core.runnables.config import run_in_executor from langchain_core.runnables.config import run_in_executor
from langchain_core.vectorstores import VectorStore from langchain_core.vectorstores import VectorStore
from langchain_community.docstore.document import Document
from langchain_community.vectorstores.utils import maximal_marginal_relevance from langchain_community.vectorstores.utils import maximal_marginal_relevance
if TYPE_CHECKING: if TYPE_CHECKING:
@@ -94,6 +94,7 @@ class Qdrant(VectorStore):
metadata_payload_key: str = METADATA_KEY, metadata_payload_key: str = METADATA_KEY,
distance_strategy: str = "COSINE", distance_strategy: str = "COSINE",
vector_name: Optional[str] = VECTOR_NAME, vector_name: Optional[str] = VECTOR_NAME,
async_client: Optional[Any] = None,
embedding_function: Optional[Callable] = None, # deprecated embedding_function: Optional[Callable] = None, # deprecated
): ):
"""Initialize with necessary components.""" """Initialize with necessary components."""
@@ -111,6 +112,14 @@ class Qdrant(VectorStore):
f"got {type(client)}" f"got {type(client)}"
) )
# The async client is optional; when one is supplied it must be a real
# AsyncQdrantClient so that later awaits on it cannot fail with an
# unrelated AttributeError.
if async_client is not None and not isinstance(
    async_client, qdrant_client.AsyncQdrantClient
):
    raise ValueError(
        "async_client should be an instance of "
        "qdrant_client.AsyncQdrantClient, "
        f"got {type(async_client)}"
    )
if embeddings is None and embedding_function is None: if embeddings is None and embedding_function is None:
raise ValueError( raise ValueError(
"`embeddings` value can't be None. Pass `Embeddings` instance." "`embeddings` value can't be None. Pass `Embeddings` instance."
@@ -125,6 +134,7 @@ class Qdrant(VectorStore):
self._embeddings = embeddings self._embeddings = embeddings
self._embeddings_function = embedding_function self._embeddings_function = embedding_function
self.client: qdrant_client.QdrantClient = client self.client: qdrant_client.QdrantClient = client
self.async_client: Optional[qdrant_client.AsyncQdrantClient] = async_client
self.collection_name = collection_name self.collection_name = collection_name
self.content_payload_key = content_payload_key or self.CONTENT_KEY self.content_payload_key = content_payload_key or self.CONTENT_KEY
self.metadata_payload_key = metadata_payload_key or self.METADATA_KEY self.metadata_payload_key = metadata_payload_key or self.METADATA_KEY
@@ -208,18 +218,21 @@ class Qdrant(VectorStore):
Returns: Returns:
List of ids from adding the texts into the vectorstore. List of ids from adding the texts into the vectorstore.
""" """
from qdrant_client import grpc # noqa from qdrant_client.local.async_qdrant_local import AsyncQdrantLocal
from qdrant_client.conversions.conversion import RestToGrpc
if self.async_client is None or isinstance(
self.async_client._client, AsyncQdrantLocal
):
raise NotImplementedError(
"QdrantLocal cannot interoperate with sync and async clients"
)
added_ids = [] added_ids = []
async for batch_ids, points in self._agenerate_rest_batches( async for batch_ids, points in self._agenerate_rest_batches(
texts, metadatas, ids, batch_size texts, metadatas, ids, batch_size
): ):
await self.client.async_grpc_points.Upsert( await self.async_client.upsert(
grpc.UpsertPoints( collection_name=self.collection_name, points=points, **kwargs
collection_name=self.collection_name,
points=[RestToGrpc.convert_point_struct(point) for point in points],
)
) )
added_ids.extend(batch_ids) added_ids.extend(batch_ids)
@@ -399,7 +412,7 @@ class Qdrant(VectorStore):
- 'all' - query all replicas, and return values present in all replicas - 'all' - query all replicas, and return values present in all replicas
**kwargs: **kwargs:
Any other named arguments to pass through to Any other named arguments to pass through to
QdrantClient.async_grpc_points.Search(). AsyncQdrantClient.Search().
Returns: Returns:
List of documents most similar to the query text and distance for each. List of documents most similar to the query text and distance for each.
@@ -514,7 +527,7 @@ class Qdrant(VectorStore):
- 'all' - query all replicas, and return values present in all replicas - 'all' - query all replicas, and return values present in all replicas
**kwargs: **kwargs:
Any other named arguments to pass through to Any other named arguments to pass through to
QdrantClient.async_grpc_points.Search(). AsyncQdrantClient.Search().
Returns: Returns:
List of Documents most similar to the query. List of Documents most similar to the query.
@@ -614,56 +627,6 @@ class Qdrant(VectorStore):
for result in results for result in results
] ]
async def _asearch_with_score_by_vector(
    self,
    embedding: List[float],
    *,
    k: int = 4,
    filter: Optional[MetadataFilter] = None,
    search_params: Optional[common_types.SearchParams] = None,
    offset: int = 0,
    score_threshold: Optional[float] = None,
    consistency: Optional[common_types.ReadConsistency] = None,
    with_vectors: bool = False,
    **kwargs: Any,
) -> Any:
    """Search the collection for points closest to ``embedding`` over gRPC.

    Args:
        embedding: Query vector.
        k: Maximum number of points to request (sent as ``limit``).
        filter: Either a qdrant-client filter or a plain dict. Dicts are
            deprecated and are converted via ``_qdrant_filter_from_dict``.
        search_params: Passed through as ``params`` of the search request.
        offset: Passed through as ``offset`` of the search request.
        score_threshold: Passed through as ``score_threshold``.
        consistency: Passed through as ``read_consistency``.
        with_vectors: Whether the response should include stored vectors.
        **kwargs: Extra fields forwarded to ``grpc.SearchPoints``.

    Returns:
        The response object returned by ``async_grpc_points.Search``.
    """
    from qdrant_client import grpc  # noqa
    from qdrant_client.conversions.conversion import RestToGrpc
    from qdrant_client.http import models as rest

    search_filter = filter
    if isinstance(filter, dict):
        warnings.warn(
            "Using dict as a `filter` is deprecated. Please use qdrant-client "
            "filters directly: "
            "https://qdrant.tech/documentation/concepts/filtering/",
            DeprecationWarning,
        )
        search_filter = self._qdrant_filter_from_dict(filter)

    # REST filter models must be converted before being placed in a gRPC
    # request; anything else (including None) is forwarded untouched.
    if isinstance(search_filter, rest.Filter):
        search_filter = RestToGrpc.convert_filter(search_filter)

    # Resolve the RPC before building the request so failures surface in
    # the same order as a single nested call expression would produce.
    search_rpc = self.client.async_grpc_points.Search
    request = grpc.SearchPoints(
        collection_name=self.collection_name,
        vector_name=self.vector_name,
        vector=embedding,
        filter=search_filter,
        params=search_params,
        limit=k,
        offset=offset,
        with_payload=grpc.WithPayloadSelector(enable=True),
        with_vectors=grpc.WithVectorsSelector(enable=with_vectors),
        score_threshold=score_threshold,
        read_consistency=consistency,
        **kwargs,
    )
    return await search_rpc(request)
@sync_call_fallback @sync_call_fallback
async def asimilarity_search_with_score_by_vector( async def asimilarity_search_with_score_by_vector(
self, self,
@@ -706,30 +669,55 @@ class Qdrant(VectorStore):
- 'all' - query all replicas, and return values present in all replicas - 'all' - query all replicas, and return values present in all replicas
**kwargs: **kwargs:
Any other named arguments to pass through to Any other named arguments to pass through to
QdrantClient.async_grpc_points.Search(). AsyncQdrantClient.Search().
Returns: Returns:
List of documents most similar to the query text and distance for each. List of documents most similar to the query text and distance for each.
""" """
response = await self._asearch_with_score_by_vector( from qdrant_client.local.async_qdrant_local import AsyncQdrantLocal
embedding,
k=k, if self.async_client is None or isinstance(
filter=filter, self.async_client._client, AsyncQdrantLocal
):
raise NotImplementedError(
"QdrantLocal cannot interoperate with sync and async clients"
)
if filter is not None and isinstance(filter, dict):
warnings.warn(
"Using dict as a `filter` is deprecated. Please use qdrant-client "
"filters directly: "
"https://qdrant.tech/documentation/concepts/filtering/",
DeprecationWarning,
)
qdrant_filter = self._qdrant_filter_from_dict(filter)
else:
qdrant_filter = filter
query_vector = embedding
if self.vector_name is not None:
query_vector = (self.vector_name, embedding) # type: ignore[assignment]
results = await self.async_client.search(
collection_name=self.collection_name,
query_vector=query_vector,
query_filter=qdrant_filter,
search_params=search_params, search_params=search_params,
limit=k,
offset=offset, offset=offset,
with_payload=True,
with_vectors=False, # Langchain does not expect vectors to be returned
score_threshold=score_threshold, score_threshold=score_threshold,
consistency=consistency, consistency=consistency,
**kwargs, **kwargs,
) )
return [ return [
( (
self._document_from_scored_point_grpc( self._document_from_scored_point(
result, self.content_payload_key, self.metadata_payload_key result, self.content_payload_key, self.metadata_payload_key
), ),
result.score, result.score,
) )
for result in response.result for result in results
] ]
def max_marginal_relevance_search( def max_marginal_relevance_search(
@@ -843,7 +831,7 @@ class Qdrant(VectorStore):
- 'all' - query all replicas, and return values present in all replicas - 'all' - query all replicas, and return values present in all replicas
**kwargs: **kwargs:
Any other named arguments to pass through to Any other named arguments to pass through to
QdrantClient.async_grpc_points.Search(). AsyncQdrantClient.Search().
Returns: Returns:
List of Documents selected by maximal marginal relevance. List of Documents selected by maximal marginal relevance.
""" """
@@ -968,7 +956,7 @@ class Qdrant(VectorStore):
- 'all' - query all replicas, and return values present in all replicas - 'all' - query all replicas, and return values present in all replicas
**kwargs: **kwargs:
Any other named arguments to pass through to Any other named arguments to pass through to
QdrantClient.async_grpc_points.Search(). AsyncQdrantClient.Search().
Returns: Returns:
List of Documents selected by maximal marginal relevance and distance for List of Documents selected by maximal marginal relevance and distance for
each. each.
@@ -1099,41 +1087,45 @@ class Qdrant(VectorStore):
List of Documents selected by maximal marginal relevance and distance for List of Documents selected by maximal marginal relevance and distance for
each. each.
""" """
from qdrant_client.conversions.conversion import GrpcToRest from qdrant_client.local.async_qdrant_local import AsyncQdrantLocal
response = await self._asearch_with_score_by_vector( if self.async_client is None or isinstance(
embedding, self.async_client._client, AsyncQdrantLocal
k=fetch_k, ):
filter=filter, raise NotImplementedError(
"QdrantLocal cannot interoperate with sync and async clients"
)
query_vector = embedding
if self.vector_name is not None:
query_vector = (self.vector_name, query_vector) # type: ignore[assignment]
results = await self.async_client.search(
collection_name=self.collection_name,
query_vector=query_vector,
query_filter=filter,
search_params=search_params, search_params=search_params,
limit=fetch_k,
with_payload=True,
with_vectors=True,
score_threshold=score_threshold, score_threshold=score_threshold,
consistency=consistency, consistency=consistency,
with_vectors=True,
**kwargs, **kwargs,
) )
results = [ embeddings = [
GrpcToRest.convert_vectors(result.vectors) for result in response.result result.vector.get(self.vector_name) # type: ignore[index, union-attr]
] if self.vector_name is not None
embeddings: List[List[float]] = [ else result.vector
result.get(self.vector_name) # type: ignore
if isinstance(result, dict)
else result
for result in results for result in results
] ]
mmr_selected: List[int] = maximal_marginal_relevance( mmr_selected = maximal_marginal_relevance(
np.array(embedding), np.array(embedding), embeddings, k=k, lambda_mult=lambda_mult
embeddings,
k=k,
lambda_mult=lambda_mult,
) )
return [ return [
( (
self._document_from_scored_point_grpc( self._document_from_scored_point(
response.result[i], results[i], self.content_payload_key, self.metadata_payload_key
self.content_payload_key,
self.metadata_payload_key,
), ),
response.result[i].score, results[i].score,
) )
for i in mmr_selected for i in mmr_selected
] ]
@@ -1543,7 +1535,7 @@ class Qdrant(VectorStore):
**kwargs: Any, **kwargs: Any,
) -> Qdrant: ) -> Qdrant:
try: try:
import qdrant_client import qdrant_client # noqa
except ImportError: except ImportError:
raise ValueError( raise ValueError(
"Could not import qdrant-client python package. " "Could not import qdrant-client python package. "
@@ -1558,7 +1550,7 @@ class Qdrant(VectorStore):
vector_size = len(partial_embeddings[0]) vector_size = len(partial_embeddings[0])
collection_name = collection_name or uuid.uuid4().hex collection_name = collection_name or uuid.uuid4().hex
distance_func = distance_func.upper() distance_func = distance_func.upper()
client = qdrant_client.QdrantClient( client, async_client = cls._generate_clients(
location=location, location=location,
url=url, url=url,
port=port, port=port,
@@ -1669,6 +1661,7 @@ class Qdrant(VectorStore):
metadata_payload_key=metadata_payload_key, metadata_payload_key=metadata_payload_key,
distance_strategy=distance_func, distance_strategy=distance_func,
vector_name=vector_name, vector_name=vector_name,
async_client=async_client,
) )
return qdrant return qdrant
@@ -1707,7 +1700,7 @@ class Qdrant(VectorStore):
**kwargs: Any, **kwargs: Any,
) -> Qdrant: ) -> Qdrant:
try: try:
import qdrant_client import qdrant_client # noqa
except ImportError: except ImportError:
raise ValueError( raise ValueError(
"Could not import qdrant-client python package. " "Could not import qdrant-client python package. "
@@ -1722,7 +1715,7 @@ class Qdrant(VectorStore):
vector_size = len(partial_embeddings[0]) vector_size = len(partial_embeddings[0])
collection_name = collection_name or uuid.uuid4().hex collection_name = collection_name or uuid.uuid4().hex
distance_func = distance_func.upper() distance_func = distance_func.upper()
client = qdrant_client.QdrantClient( client, async_client = cls._generate_clients(
location=location, location=location,
url=url, url=url,
port=port, port=port,
@@ -1833,6 +1826,7 @@ class Qdrant(VectorStore):
metadata_payload_key=metadata_payload_key, metadata_payload_key=metadata_payload_key,
distance_strategy=distance_func, distance_strategy=distance_func,
vector_name=vector_name, vector_name=vector_name,
async_client=async_client,
) )
return qdrant return qdrant
@@ -1922,21 +1916,6 @@ class Qdrant(VectorStore):
metadata=scored_point.payload.get(metadata_payload_key) or {}, metadata=scored_point.payload.get(metadata_payload_key) or {},
) )
@classmethod
def _document_from_scored_point_grpc(
    cls,
    scored_point: Any,
    content_payload_key: str,
    metadata_payload_key: str,
) -> Document:
    """Build a ``Document`` from a scored point returned over gRPC.

    Args:
        scored_point: Object whose ``payload`` attribute is decoded with
            ``grpc_to_payload``.
        content_payload_key: Payload key holding the page content. It must
            be present, otherwise the lookup raises ``KeyError``.
        metadata_payload_key: Payload key holding the metadata. A missing
            or falsy value yields an empty dict.

    Returns:
        A ``Document`` carrying the decoded content and metadata.
    """
    from qdrant_client.conversions.conversion import grpc_to_payload

    decoded = grpc_to_payload(scored_point.payload)
    page_content = decoded[content_payload_key]
    metadata = decoded.get(metadata_payload_key) or {}
    return Document(page_content=page_content, metadata=metadata)
def _build_condition(self, key: str, value: Any) -> List[rest.FieldCondition]: def _build_condition(self, key: str, value: Any) -> List[rest.FieldCondition]:
from qdrant_client.http import models as rest from qdrant_client.http import models as rest
@@ -2134,3 +2113,57 @@ class Qdrant(VectorStore):
] ]
yield batch_ids, points yield batch_ids, points
@staticmethod
def _generate_clients(
    location: Optional[str] = None,
    url: Optional[str] = None,
    port: Optional[int] = 6333,
    grpc_port: int = 6334,
    prefer_grpc: bool = False,
    https: Optional[bool] = None,
    api_key: Optional[str] = None,
    prefix: Optional[str] = None,
    timeout: Optional[float] = None,
    host: Optional[str] = None,
    path: Optional[str] = None,
    **kwargs: Any,
) -> Tuple[Any, Any]:
    """Create the sync client and, where supported, a matching async client.

    Both clients are constructed from the same connection arguments. Every
    named parameter and any extra ``kwargs`` are forwarded unchanged to the
    client constructors.

    Returns:
        A ``(sync_client, async_client)`` tuple. ``async_client`` is ``None``
        when ``location`` is ``":memory:"`` or ``path`` is set.
    """
    from qdrant_client import AsyncQdrantClient, QdrantClient

    connection_options = {
        "location": location,
        "url": url,
        "port": port,
        "grpc_port": grpc_port,
        "prefer_grpc": prefer_grpc,
        "https": https,
        "api_key": api_key,
        "prefix": prefix,
        "timeout": timeout,
        "host": host,
        "path": path,
    }

    sync_client = QdrantClient(**connection_options, **kwargs)

    # Local Qdrant cannot co-exist with Sync and Async clients, so in that
    # mode no async client is created and callers fall back to sync
    # operations.
    uses_local_mode = location == ":memory:" or path is not None
    if uses_local_mode:
        return sync_client, None

    return sync_client, AsyncQdrantClient(**connection_options, **kwargs)