mirror of
https://github.com/hwchase17/langchain.git
synced 2025-09-12 21:11:43 +00:00
refactor: Qdrant async improvements (#14492)
Follow-up to https://github.com/langchain-ai/langchain/pull/13048. This PR simplifies the Qdrant async implementation by replacing the internal gRPC calls with `AsyncQdrantClient` methods. The change is backward compatible and requires no additional steps after merging. --------- Co-authored-by: Harrison Chase <hw.chase.17@gmail.com>
This commit is contained in:
@@ -22,11 +22,11 @@ from typing import (
|
||||
)
|
||||
|
||||
import numpy as np
|
||||
from langchain_core.documents import Document
|
||||
from langchain_core.embeddings import Embeddings
|
||||
from langchain_core.runnables.config import run_in_executor
|
||||
from langchain_core.vectorstores import VectorStore
|
||||
|
||||
from langchain_community.docstore.document import Document
|
||||
from langchain_community.vectorstores.utils import maximal_marginal_relevance
|
||||
|
||||
if TYPE_CHECKING:
|
||||
@@ -94,6 +94,7 @@ class Qdrant(VectorStore):
|
||||
metadata_payload_key: str = METADATA_KEY,
|
||||
distance_strategy: str = "COSINE",
|
||||
vector_name: Optional[str] = VECTOR_NAME,
|
||||
async_client: Optional[Any] = None,
|
||||
embedding_function: Optional[Callable] = None, # deprecated
|
||||
):
|
||||
"""Initialize with necessary components."""
|
||||
@@ -111,6 +112,14 @@ class Qdrant(VectorStore):
|
||||
f"got {type(client)}"
|
||||
)
|
||||
|
||||
if async_client is not None and not isinstance(
|
||||
async_client, qdrant_client.AsyncQdrantClient
|
||||
):
|
||||
raise ValueError(
|
||||
f"async_client should be an instance of qdrant_client.AsyncQdrantClient"
|
||||
f"got {type(async_client)}"
|
||||
)
|
||||
|
||||
if embeddings is None and embedding_function is None:
|
||||
raise ValueError(
|
||||
"`embeddings` value can't be None. Pass `Embeddings` instance."
|
||||
@@ -125,6 +134,7 @@ class Qdrant(VectorStore):
|
||||
self._embeddings = embeddings
|
||||
self._embeddings_function = embedding_function
|
||||
self.client: qdrant_client.QdrantClient = client
|
||||
self.async_client: Optional[qdrant_client.AsyncQdrantClient] = async_client
|
||||
self.collection_name = collection_name
|
||||
self.content_payload_key = content_payload_key or self.CONTENT_KEY
|
||||
self.metadata_payload_key = metadata_payload_key or self.METADATA_KEY
|
||||
@@ -208,18 +218,21 @@ class Qdrant(VectorStore):
|
||||
Returns:
|
||||
List of ids from adding the texts into the vectorstore.
|
||||
"""
|
||||
from qdrant_client import grpc # noqa
|
||||
from qdrant_client.conversions.conversion import RestToGrpc
|
||||
from qdrant_client.local.async_qdrant_local import AsyncQdrantLocal
|
||||
|
||||
if self.async_client is None or isinstance(
|
||||
self.async_client._client, AsyncQdrantLocal
|
||||
):
|
||||
raise NotImplementedError(
|
||||
"QdrantLocal cannot interoperate with sync and async clients"
|
||||
)
|
||||
|
||||
added_ids = []
|
||||
async for batch_ids, points in self._agenerate_rest_batches(
|
||||
texts, metadatas, ids, batch_size
|
||||
):
|
||||
await self.client.async_grpc_points.Upsert(
|
||||
grpc.UpsertPoints(
|
||||
collection_name=self.collection_name,
|
||||
points=[RestToGrpc.convert_point_struct(point) for point in points],
|
||||
)
|
||||
await self.async_client.upsert(
|
||||
collection_name=self.collection_name, points=points, **kwargs
|
||||
)
|
||||
added_ids.extend(batch_ids)
|
||||
|
||||
@@ -399,7 +412,7 @@ class Qdrant(VectorStore):
|
||||
- 'all' - query all replicas, and return values present in all replicas
|
||||
**kwargs:
|
||||
Any other named arguments to pass through to
|
||||
QdrantClient.async_grpc_points.Search().
|
||||
AsyncQdrantClient.search().
|
||||
|
||||
Returns:
|
||||
List of documents most similar to the query text and distance for each.
|
||||
@@ -514,7 +527,7 @@ class Qdrant(VectorStore):
|
||||
- 'all' - query all replicas, and return values present in all replicas
|
||||
**kwargs:
|
||||
Any other named arguments to pass through to
|
||||
QdrantClient.async_grpc_points.Search().
|
||||
AsyncQdrantClient.search().
|
||||
|
||||
Returns:
|
||||
List of Documents most similar to the query.
|
||||
@@ -614,56 +627,6 @@ class Qdrant(VectorStore):
|
||||
for result in results
|
||||
]
|
||||
|
||||
async def _asearch_with_score_by_vector(
    self,
    embedding: List[float],
    *,
    k: int = 4,
    filter: Optional[MetadataFilter] = None,
    search_params: Optional[common_types.SearchParams] = None,
    offset: int = 0,
    score_threshold: Optional[float] = None,
    consistency: Optional[common_types.ReadConsistency] = None,
    with_vectors: bool = False,
    **kwargs: Any,
) -> Any:
    """Search the collection over the async gRPC stub and return the raw response.

    Args:
        embedding: Query vector.
        k: Maximum number of points to request (sent as ``limit``).
        filter: A qdrant-client filter, or a plain dict (deprecated; a
            ``DeprecationWarning`` is emitted and the dict is converted via
            ``self._qdrant_filter_from_dict``).
        search_params: Forwarded as ``params`` of the search request.
        offset: Forwarded as ``offset`` of the search request.
        score_threshold: Forwarded as ``score_threshold`` of the search request.
        consistency: Forwarded as ``read_consistency`` of the search request.
        with_vectors: Whether the response should include the stored vectors.
        **kwargs: Extra fields passed straight to ``grpc.SearchPoints``.

    Returns:
        Whatever ``self.client.async_grpc_points.Search`` resolves to; the
        response is handed back without any post-processing.
    """
    from qdrant_client import grpc  # noqa
    from qdrant_client.conversions.conversion import RestToGrpc
    from qdrant_client.http import models as rest

    # ``isinstance`` is already False for None, so no separate None check is needed.
    if isinstance(filter, dict):
        warnings.warn(
            "Using dict as a `filter` is deprecated. Please use qdrant-client "
            "filters directly: "
            "https://qdrant.tech/documentation/concepts/filtering/",
            DeprecationWarning,
        )
        search_filter = self._qdrant_filter_from_dict(filter)
    else:
        search_filter = filter

    # The gRPC request needs the gRPC flavour of the filter, not the REST model.
    if isinstance(search_filter, rest.Filter):
        search_filter = RestToGrpc.convert_filter(search_filter)

    search_rpc = self.client.async_grpc_points.Search
    request = grpc.SearchPoints(
        collection_name=self.collection_name,
        vector_name=self.vector_name,
        vector=embedding,
        filter=search_filter,
        params=search_params,
        limit=k,
        offset=offset,
        with_payload=grpc.WithPayloadSelector(enable=True),
        with_vectors=grpc.WithVectorsSelector(enable=with_vectors),
        score_threshold=score_threshold,
        read_consistency=consistency,
        **kwargs,
    )
    return await search_rpc(request)
|
||||
|
||||
@sync_call_fallback
|
||||
async def asimilarity_search_with_score_by_vector(
|
||||
self,
|
||||
@@ -706,30 +669,55 @@ class Qdrant(VectorStore):
|
||||
- 'all' - query all replicas, and return values present in all replicas
|
||||
**kwargs:
|
||||
Any other named arguments to pass through to
|
||||
QdrantClient.async_grpc_points.Search().
|
||||
AsyncQdrantClient.search().
|
||||
|
||||
Returns:
|
||||
List of documents most similar to the query text and distance for each.
|
||||
"""
|
||||
response = await self._asearch_with_score_by_vector(
|
||||
embedding,
|
||||
k=k,
|
||||
filter=filter,
|
||||
from qdrant_client.local.async_qdrant_local import AsyncQdrantLocal
|
||||
|
||||
if self.async_client is None or isinstance(
|
||||
self.async_client._client, AsyncQdrantLocal
|
||||
):
|
||||
raise NotImplementedError(
|
||||
"QdrantLocal cannot interoperate with sync and async clients"
|
||||
)
|
||||
if filter is not None and isinstance(filter, dict):
|
||||
warnings.warn(
|
||||
"Using dict as a `filter` is deprecated. Please use qdrant-client "
|
||||
"filters directly: "
|
||||
"https://qdrant.tech/documentation/concepts/filtering/",
|
||||
DeprecationWarning,
|
||||
)
|
||||
qdrant_filter = self._qdrant_filter_from_dict(filter)
|
||||
else:
|
||||
qdrant_filter = filter
|
||||
|
||||
query_vector = embedding
|
||||
if self.vector_name is not None:
|
||||
query_vector = (self.vector_name, embedding) # type: ignore[assignment]
|
||||
|
||||
results = await self.async_client.search(
|
||||
collection_name=self.collection_name,
|
||||
query_vector=query_vector,
|
||||
query_filter=qdrant_filter,
|
||||
search_params=search_params,
|
||||
limit=k,
|
||||
offset=offset,
|
||||
with_payload=True,
|
||||
with_vectors=False, # Langchain does not expect vectors to be returned
|
||||
score_threshold=score_threshold,
|
||||
consistency=consistency,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
return [
|
||||
(
|
||||
self._document_from_scored_point_grpc(
|
||||
self._document_from_scored_point(
|
||||
result, self.content_payload_key, self.metadata_payload_key
|
||||
),
|
||||
result.score,
|
||||
)
|
||||
for result in response.result
|
||||
for result in results
|
||||
]
|
||||
|
||||
def max_marginal_relevance_search(
|
||||
@@ -843,7 +831,7 @@ class Qdrant(VectorStore):
|
||||
- 'all' - query all replicas, and return values present in all replicas
|
||||
**kwargs:
|
||||
Any other named arguments to pass through to
|
||||
QdrantClient.async_grpc_points.Search().
|
||||
AsyncQdrantClient.search().
|
||||
Returns:
|
||||
List of Documents selected by maximal marginal relevance.
|
||||
"""
|
||||
@@ -968,7 +956,7 @@ class Qdrant(VectorStore):
|
||||
- 'all' - query all replicas, and return values present in all replicas
|
||||
**kwargs:
|
||||
Any other named arguments to pass through to
|
||||
QdrantClient.async_grpc_points.Search().
|
||||
AsyncQdrantClient.search().
|
||||
Returns:
|
||||
List of Documents selected by maximal marginal relevance and distance for
|
||||
each.
|
||||
@@ -1099,41 +1087,45 @@ class Qdrant(VectorStore):
|
||||
List of Documents selected by maximal marginal relevance and distance for
|
||||
each.
|
||||
"""
|
||||
from qdrant_client.conversions.conversion import GrpcToRest
|
||||
from qdrant_client.local.async_qdrant_local import AsyncQdrantLocal
|
||||
|
||||
response = await self._asearch_with_score_by_vector(
|
||||
embedding,
|
||||
k=fetch_k,
|
||||
filter=filter,
|
||||
if self.async_client is None or isinstance(
|
||||
self.async_client._client, AsyncQdrantLocal
|
||||
):
|
||||
raise NotImplementedError(
|
||||
"QdrantLocal cannot interoperate with sync and async clients"
|
||||
)
|
||||
query_vector = embedding
|
||||
if self.vector_name is not None:
|
||||
query_vector = (self.vector_name, query_vector) # type: ignore[assignment]
|
||||
|
||||
results = await self.async_client.search(
|
||||
collection_name=self.collection_name,
|
||||
query_vector=query_vector,
|
||||
query_filter=filter,
|
||||
search_params=search_params,
|
||||
limit=fetch_k,
|
||||
with_payload=True,
|
||||
with_vectors=True,
|
||||
score_threshold=score_threshold,
|
||||
consistency=consistency,
|
||||
with_vectors=True,
|
||||
**kwargs,
|
||||
)
|
||||
results = [
|
||||
GrpcToRest.convert_vectors(result.vectors) for result in response.result
|
||||
]
|
||||
embeddings: List[List[float]] = [
|
||||
result.get(self.vector_name) # type: ignore
|
||||
if isinstance(result, dict)
|
||||
else result
|
||||
embeddings = [
|
||||
result.vector.get(self.vector_name) # type: ignore[index, union-attr]
|
||||
if self.vector_name is not None
|
||||
else result.vector
|
||||
for result in results
|
||||
]
|
||||
mmr_selected: List[int] = maximal_marginal_relevance(
|
||||
np.array(embedding),
|
||||
embeddings,
|
||||
k=k,
|
||||
lambda_mult=lambda_mult,
|
||||
mmr_selected = maximal_marginal_relevance(
|
||||
np.array(embedding), embeddings, k=k, lambda_mult=lambda_mult
|
||||
)
|
||||
return [
|
||||
(
|
||||
self._document_from_scored_point_grpc(
|
||||
response.result[i],
|
||||
self.content_payload_key,
|
||||
self.metadata_payload_key,
|
||||
self._document_from_scored_point(
|
||||
results[i], self.content_payload_key, self.metadata_payload_key
|
||||
),
|
||||
response.result[i].score,
|
||||
results[i].score,
|
||||
)
|
||||
for i in mmr_selected
|
||||
]
|
||||
@@ -1543,7 +1535,7 @@ class Qdrant(VectorStore):
|
||||
**kwargs: Any,
|
||||
) -> Qdrant:
|
||||
try:
|
||||
import qdrant_client
|
||||
import qdrant_client # noqa
|
||||
except ImportError:
|
||||
raise ValueError(
|
||||
"Could not import qdrant-client python package. "
|
||||
@@ -1558,7 +1550,7 @@ class Qdrant(VectorStore):
|
||||
vector_size = len(partial_embeddings[0])
|
||||
collection_name = collection_name or uuid.uuid4().hex
|
||||
distance_func = distance_func.upper()
|
||||
client = qdrant_client.QdrantClient(
|
||||
client, async_client = cls._generate_clients(
|
||||
location=location,
|
||||
url=url,
|
||||
port=port,
|
||||
@@ -1669,6 +1661,7 @@ class Qdrant(VectorStore):
|
||||
metadata_payload_key=metadata_payload_key,
|
||||
distance_strategy=distance_func,
|
||||
vector_name=vector_name,
|
||||
async_client=async_client,
|
||||
)
|
||||
return qdrant
|
||||
|
||||
@@ -1707,7 +1700,7 @@ class Qdrant(VectorStore):
|
||||
**kwargs: Any,
|
||||
) -> Qdrant:
|
||||
try:
|
||||
import qdrant_client
|
||||
import qdrant_client # noqa
|
||||
except ImportError:
|
||||
raise ValueError(
|
||||
"Could not import qdrant-client python package. "
|
||||
@@ -1722,7 +1715,7 @@ class Qdrant(VectorStore):
|
||||
vector_size = len(partial_embeddings[0])
|
||||
collection_name = collection_name or uuid.uuid4().hex
|
||||
distance_func = distance_func.upper()
|
||||
client = qdrant_client.QdrantClient(
|
||||
client, async_client = cls._generate_clients(
|
||||
location=location,
|
||||
url=url,
|
||||
port=port,
|
||||
@@ -1833,6 +1826,7 @@ class Qdrant(VectorStore):
|
||||
metadata_payload_key=metadata_payload_key,
|
||||
distance_strategy=distance_func,
|
||||
vector_name=vector_name,
|
||||
async_client=async_client,
|
||||
)
|
||||
return qdrant
|
||||
|
||||
@@ -1922,21 +1916,6 @@ class Qdrant(VectorStore):
|
||||
metadata=scored_point.payload.get(metadata_payload_key) or {},
|
||||
)
|
||||
|
||||
@classmethod
def _document_from_scored_point_grpc(
    cls,
    scored_point: Any,
    content_payload_key: str,
    metadata_payload_key: str,
) -> Document:
    """Build a ``Document`` from a scored point returned over gRPC.

    Args:
        scored_point: Object exposing a gRPC ``payload`` attribute.
        content_payload_key: Payload key holding the page content. It must be
            present; a missing key raises ``KeyError``.
        metadata_payload_key: Payload key holding the metadata. A missing or
            falsy value yields an empty dict.

    Returns:
        A ``Document`` carrying the decoded content and metadata.
    """
    from qdrant_client.conversions.conversion import grpc_to_payload

    decoded = grpc_to_payload(scored_point.payload)
    content = decoded[content_payload_key]
    metadata = decoded.get(metadata_payload_key) or {}
    return Document(page_content=content, metadata=metadata)
|
||||
|
||||
def _build_condition(self, key: str, value: Any) -> List[rest.FieldCondition]:
|
||||
from qdrant_client.http import models as rest
|
||||
|
||||
@@ -2134,3 +2113,57 @@ class Qdrant(VectorStore):
|
||||
]
|
||||
|
||||
yield batch_ids, points
|
||||
|
||||
@staticmethod
def _generate_clients(
    location: Optional[str] = None,
    url: Optional[str] = None,
    port: Optional[int] = 6333,
    grpc_port: int = 6334,
    prefer_grpc: bool = False,
    https: Optional[bool] = None,
    api_key: Optional[str] = None,
    prefix: Optional[str] = None,
    timeout: Optional[float] = None,
    host: Optional[str] = None,
    path: Optional[str] = None,
    **kwargs: Any,
) -> Tuple[Any, Any]:
    """Create the sync client and, when possible, a matching async client.

    All arguments are forwarded unchanged to ``QdrantClient`` and, for
    non-local setups, to ``AsyncQdrantClient`` as well.

    Returns:
        ``(sync_client, async_client)``. ``async_client`` is ``None`` when
        ``location`` is ``":memory:"`` or ``path`` is given.
    """
    from qdrant_client import AsyncQdrantClient, QdrantClient

    # Both clients take exactly the same options, so collect them once.
    client_options = dict(
        location=location,
        url=url,
        port=port,
        grpc_port=grpc_port,
        prefer_grpc=prefer_grpc,
        https=https,
        api_key=api_key,
        prefix=prefix,
        timeout=timeout,
        host=host,
        path=path,
        **kwargs,
    )
    sync_client = QdrantClient(**client_options)

    # Local Qdrant cannot co-exist with sync and async clients, so only the
    # sync client is created and callers fall back to sync operations.
    is_local = location == ":memory:" or path is not None
    if is_local:
        return sync_client, None

    return sync_client, AsyncQdrantClient(**client_options)
|
||||
|
Reference in New Issue
Block a user