mirror of
https://github.com/hwchase17/langchain.git
synced 2025-09-13 13:36:15 +00:00
refactor: Qdrant async improvements (#14492)
Follow-up to https://github.com/langchain-ai/langchain/pull/13048. This PR simplifies the Qdrant async implementation by replacing the internal gRPC calls (`async_grpc_points`) with the corresponding `AsyncQdrantClient` methods. This is a backward-compatible change; no additional steps are required after merging. --------- Co-authored-by: Harrison Chase <hw.chase.17@gmail.com>
This commit is contained in:
@@ -22,11 +22,11 @@ from typing import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from langchain_core.documents import Document
|
|
||||||
from langchain_core.embeddings import Embeddings
|
from langchain_core.embeddings import Embeddings
|
||||||
from langchain_core.runnables.config import run_in_executor
|
from langchain_core.runnables.config import run_in_executor
|
||||||
from langchain_core.vectorstores import VectorStore
|
from langchain_core.vectorstores import VectorStore
|
||||||
|
|
||||||
|
from langchain_community.docstore.document import Document
|
||||||
from langchain_community.vectorstores.utils import maximal_marginal_relevance
|
from langchain_community.vectorstores.utils import maximal_marginal_relevance
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
@@ -94,6 +94,7 @@ class Qdrant(VectorStore):
|
|||||||
metadata_payload_key: str = METADATA_KEY,
|
metadata_payload_key: str = METADATA_KEY,
|
||||||
distance_strategy: str = "COSINE",
|
distance_strategy: str = "COSINE",
|
||||||
vector_name: Optional[str] = VECTOR_NAME,
|
vector_name: Optional[str] = VECTOR_NAME,
|
||||||
|
async_client: Optional[Any] = None,
|
||||||
embedding_function: Optional[Callable] = None, # deprecated
|
embedding_function: Optional[Callable] = None, # deprecated
|
||||||
):
|
):
|
||||||
"""Initialize with necessary components."""
|
"""Initialize with necessary components."""
|
||||||
@@ -111,6 +112,14 @@ class Qdrant(VectorStore):
|
|||||||
f"got {type(client)}"
|
f"got {type(client)}"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if async_client is not None and not isinstance(
|
||||||
|
async_client, qdrant_client.AsyncQdrantClient
|
||||||
|
):
|
||||||
|
raise ValueError(
|
||||||
|
f"async_client should be an instance of qdrant_client.AsyncQdrantClient"
|
||||||
|
f"got {type(async_client)}"
|
||||||
|
)
|
||||||
|
|
||||||
if embeddings is None and embedding_function is None:
|
if embeddings is None and embedding_function is None:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
"`embeddings` value can't be None. Pass `Embeddings` instance."
|
"`embeddings` value can't be None. Pass `Embeddings` instance."
|
||||||
@@ -125,6 +134,7 @@ class Qdrant(VectorStore):
|
|||||||
self._embeddings = embeddings
|
self._embeddings = embeddings
|
||||||
self._embeddings_function = embedding_function
|
self._embeddings_function = embedding_function
|
||||||
self.client: qdrant_client.QdrantClient = client
|
self.client: qdrant_client.QdrantClient = client
|
||||||
|
self.async_client: Optional[qdrant_client.AsyncQdrantClient] = async_client
|
||||||
self.collection_name = collection_name
|
self.collection_name = collection_name
|
||||||
self.content_payload_key = content_payload_key or self.CONTENT_KEY
|
self.content_payload_key = content_payload_key or self.CONTENT_KEY
|
||||||
self.metadata_payload_key = metadata_payload_key or self.METADATA_KEY
|
self.metadata_payload_key = metadata_payload_key or self.METADATA_KEY
|
||||||
@@ -208,18 +218,21 @@ class Qdrant(VectorStore):
|
|||||||
Returns:
|
Returns:
|
||||||
List of ids from adding the texts into the vectorstore.
|
List of ids from adding the texts into the vectorstore.
|
||||||
"""
|
"""
|
||||||
from qdrant_client import grpc # noqa
|
from qdrant_client.local.async_qdrant_local import AsyncQdrantLocal
|
||||||
from qdrant_client.conversions.conversion import RestToGrpc
|
|
||||||
|
if self.async_client is None or isinstance(
|
||||||
|
self.async_client._client, AsyncQdrantLocal
|
||||||
|
):
|
||||||
|
raise NotImplementedError(
|
||||||
|
"QdrantLocal cannot interoperate with sync and async clients"
|
||||||
|
)
|
||||||
|
|
||||||
added_ids = []
|
added_ids = []
|
||||||
async for batch_ids, points in self._agenerate_rest_batches(
|
async for batch_ids, points in self._agenerate_rest_batches(
|
||||||
texts, metadatas, ids, batch_size
|
texts, metadatas, ids, batch_size
|
||||||
):
|
):
|
||||||
await self.client.async_grpc_points.Upsert(
|
await self.async_client.upsert(
|
||||||
grpc.UpsertPoints(
|
collection_name=self.collection_name, points=points, **kwargs
|
||||||
collection_name=self.collection_name,
|
|
||||||
points=[RestToGrpc.convert_point_struct(point) for point in points],
|
|
||||||
)
|
|
||||||
)
|
)
|
||||||
added_ids.extend(batch_ids)
|
added_ids.extend(batch_ids)
|
||||||
|
|
||||||
@@ -399,7 +412,7 @@ class Qdrant(VectorStore):
|
|||||||
- 'all' - query all replicas, and return values present in all replicas
|
- 'all' - query all replicas, and return values present in all replicas
|
||||||
**kwargs:
|
**kwargs:
|
||||||
Any other named arguments to pass through to
|
Any other named arguments to pass through to
|
||||||
QdrantClient.async_grpc_points.Search().
|
AsyncQdrantClient.search().
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
List of documents most similar to the query text and distance for each.
|
List of documents most similar to the query text and distance for each.
|
||||||
@@ -514,7 +527,7 @@ class Qdrant(VectorStore):
|
|||||||
- 'all' - query all replicas, and return values present in all replicas
|
- 'all' - query all replicas, and return values present in all replicas
|
||||||
**kwargs:
|
**kwargs:
|
||||||
Any other named arguments to pass through to
|
Any other named arguments to pass through to
|
||||||
QdrantClient.async_grpc_points.Search().
|
AsyncQdrantClient.search().
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
List of Documents most similar to the query.
|
List of Documents most similar to the query.
|
||||||
@@ -614,56 +627,6 @@ class Qdrant(VectorStore):
|
|||||||
for result in results
|
for result in results
|
||||||
]
|
]
|
||||||
|
|
||||||
async def _asearch_with_score_by_vector(
|
|
||||||
self,
|
|
||||||
embedding: List[float],
|
|
||||||
*,
|
|
||||||
k: int = 4,
|
|
||||||
filter: Optional[MetadataFilter] = None,
|
|
||||||
search_params: Optional[common_types.SearchParams] = None,
|
|
||||||
offset: int = 0,
|
|
||||||
score_threshold: Optional[float] = None,
|
|
||||||
consistency: Optional[common_types.ReadConsistency] = None,
|
|
||||||
with_vectors: bool = False,
|
|
||||||
**kwargs: Any,
|
|
||||||
) -> Any:
|
|
||||||
"""Return results most similar to embedding vector."""
|
|
||||||
from qdrant_client import grpc # noqa
|
|
||||||
from qdrant_client.conversions.conversion import RestToGrpc
|
|
||||||
from qdrant_client.http import models as rest
|
|
||||||
|
|
||||||
if filter is not None and isinstance(filter, dict):
|
|
||||||
warnings.warn(
|
|
||||||
"Using dict as a `filter` is deprecated. Please use qdrant-client "
|
|
||||||
"filters directly: "
|
|
||||||
"https://qdrant.tech/documentation/concepts/filtering/",
|
|
||||||
DeprecationWarning,
|
|
||||||
)
|
|
||||||
qdrant_filter = self._qdrant_filter_from_dict(filter)
|
|
||||||
else:
|
|
||||||
qdrant_filter = filter
|
|
||||||
|
|
||||||
if qdrant_filter is not None and isinstance(qdrant_filter, rest.Filter):
|
|
||||||
qdrant_filter = RestToGrpc.convert_filter(qdrant_filter)
|
|
||||||
|
|
||||||
response = await self.client.async_grpc_points.Search(
|
|
||||||
grpc.SearchPoints(
|
|
||||||
collection_name=self.collection_name,
|
|
||||||
vector_name=self.vector_name,
|
|
||||||
vector=embedding,
|
|
||||||
filter=qdrant_filter,
|
|
||||||
params=search_params,
|
|
||||||
limit=k,
|
|
||||||
offset=offset,
|
|
||||||
with_payload=grpc.WithPayloadSelector(enable=True),
|
|
||||||
with_vectors=grpc.WithVectorsSelector(enable=with_vectors),
|
|
||||||
score_threshold=score_threshold,
|
|
||||||
read_consistency=consistency,
|
|
||||||
**kwargs,
|
|
||||||
)
|
|
||||||
)
|
|
||||||
return response
|
|
||||||
|
|
||||||
@sync_call_fallback
|
@sync_call_fallback
|
||||||
async def asimilarity_search_with_score_by_vector(
|
async def asimilarity_search_with_score_by_vector(
|
||||||
self,
|
self,
|
||||||
@@ -706,30 +669,55 @@ class Qdrant(VectorStore):
|
|||||||
- 'all' - query all replicas, and return values present in all replicas
|
- 'all' - query all replicas, and return values present in all replicas
|
||||||
**kwargs:
|
**kwargs:
|
||||||
Any other named arguments to pass through to
|
Any other named arguments to pass through to
|
||||||
QdrantClient.async_grpc_points.Search().
|
AsyncQdrantClient.search().
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
List of documents most similar to the query text and distance for each.
|
List of documents most similar to the query text and distance for each.
|
||||||
"""
|
"""
|
||||||
response = await self._asearch_with_score_by_vector(
|
from qdrant_client.local.async_qdrant_local import AsyncQdrantLocal
|
||||||
embedding,
|
|
||||||
k=k,
|
if self.async_client is None or isinstance(
|
||||||
filter=filter,
|
self.async_client._client, AsyncQdrantLocal
|
||||||
|
):
|
||||||
|
raise NotImplementedError(
|
||||||
|
"QdrantLocal cannot interoperate with sync and async clients"
|
||||||
|
)
|
||||||
|
if filter is not None and isinstance(filter, dict):
|
||||||
|
warnings.warn(
|
||||||
|
"Using dict as a `filter` is deprecated. Please use qdrant-client "
|
||||||
|
"filters directly: "
|
||||||
|
"https://qdrant.tech/documentation/concepts/filtering/",
|
||||||
|
DeprecationWarning,
|
||||||
|
)
|
||||||
|
qdrant_filter = self._qdrant_filter_from_dict(filter)
|
||||||
|
else:
|
||||||
|
qdrant_filter = filter
|
||||||
|
|
||||||
|
query_vector = embedding
|
||||||
|
if self.vector_name is not None:
|
||||||
|
query_vector = (self.vector_name, embedding) # type: ignore[assignment]
|
||||||
|
|
||||||
|
results = await self.async_client.search(
|
||||||
|
collection_name=self.collection_name,
|
||||||
|
query_vector=query_vector,
|
||||||
|
query_filter=qdrant_filter,
|
||||||
search_params=search_params,
|
search_params=search_params,
|
||||||
|
limit=k,
|
||||||
offset=offset,
|
offset=offset,
|
||||||
|
with_payload=True,
|
||||||
|
with_vectors=False, # Langchain does not expect vectors to be returned
|
||||||
score_threshold=score_threshold,
|
score_threshold=score_threshold,
|
||||||
consistency=consistency,
|
consistency=consistency,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
)
|
)
|
||||||
|
|
||||||
return [
|
return [
|
||||||
(
|
(
|
||||||
self._document_from_scored_point_grpc(
|
self._document_from_scored_point(
|
||||||
result, self.content_payload_key, self.metadata_payload_key
|
result, self.content_payload_key, self.metadata_payload_key
|
||||||
),
|
),
|
||||||
result.score,
|
result.score,
|
||||||
)
|
)
|
||||||
for result in response.result
|
for result in results
|
||||||
]
|
]
|
||||||
|
|
||||||
def max_marginal_relevance_search(
|
def max_marginal_relevance_search(
|
||||||
@@ -843,7 +831,7 @@ class Qdrant(VectorStore):
|
|||||||
- 'all' - query all replicas, and return values present in all replicas
|
- 'all' - query all replicas, and return values present in all replicas
|
||||||
**kwargs:
|
**kwargs:
|
||||||
Any other named arguments to pass through to
|
Any other named arguments to pass through to
|
||||||
QdrantClient.async_grpc_points.Search().
|
AsyncQdrantClient.search().
|
||||||
Returns:
|
Returns:
|
||||||
List of Documents selected by maximal marginal relevance.
|
List of Documents selected by maximal marginal relevance.
|
||||||
"""
|
"""
|
||||||
@@ -968,7 +956,7 @@ class Qdrant(VectorStore):
|
|||||||
- 'all' - query all replicas, and return values present in all replicas
|
- 'all' - query all replicas, and return values present in all replicas
|
||||||
**kwargs:
|
**kwargs:
|
||||||
Any other named arguments to pass through to
|
Any other named arguments to pass through to
|
||||||
QdrantClient.async_grpc_points.Search().
|
AsyncQdrantClient.search().
|
||||||
Returns:
|
Returns:
|
||||||
List of Documents selected by maximal marginal relevance and distance for
|
List of Documents selected by maximal marginal relevance and distance for
|
||||||
each.
|
each.
|
||||||
@@ -1099,41 +1087,45 @@ class Qdrant(VectorStore):
|
|||||||
List of Documents selected by maximal marginal relevance and distance for
|
List of Documents selected by maximal marginal relevance and distance for
|
||||||
each.
|
each.
|
||||||
"""
|
"""
|
||||||
from qdrant_client.conversions.conversion import GrpcToRest
|
from qdrant_client.local.async_qdrant_local import AsyncQdrantLocal
|
||||||
|
|
||||||
response = await self._asearch_with_score_by_vector(
|
if self.async_client is None or isinstance(
|
||||||
embedding,
|
self.async_client._client, AsyncQdrantLocal
|
||||||
k=fetch_k,
|
):
|
||||||
filter=filter,
|
raise NotImplementedError(
|
||||||
|
"QdrantLocal cannot interoperate with sync and async clients"
|
||||||
|
)
|
||||||
|
query_vector = embedding
|
||||||
|
if self.vector_name is not None:
|
||||||
|
query_vector = (self.vector_name, query_vector) # type: ignore[assignment]
|
||||||
|
|
||||||
|
results = await self.async_client.search(
|
||||||
|
collection_name=self.collection_name,
|
||||||
|
query_vector=query_vector,
|
||||||
|
query_filter=filter,
|
||||||
search_params=search_params,
|
search_params=search_params,
|
||||||
|
limit=fetch_k,
|
||||||
|
with_payload=True,
|
||||||
|
with_vectors=True,
|
||||||
score_threshold=score_threshold,
|
score_threshold=score_threshold,
|
||||||
consistency=consistency,
|
consistency=consistency,
|
||||||
with_vectors=True,
|
|
||||||
**kwargs,
|
**kwargs,
|
||||||
)
|
)
|
||||||
results = [
|
embeddings = [
|
||||||
GrpcToRest.convert_vectors(result.vectors) for result in response.result
|
result.vector.get(self.vector_name) # type: ignore[index, union-attr]
|
||||||
]
|
if self.vector_name is not None
|
||||||
embeddings: List[List[float]] = [
|
else result.vector
|
||||||
result.get(self.vector_name) # type: ignore
|
|
||||||
if isinstance(result, dict)
|
|
||||||
else result
|
|
||||||
for result in results
|
for result in results
|
||||||
]
|
]
|
||||||
mmr_selected: List[int] = maximal_marginal_relevance(
|
mmr_selected = maximal_marginal_relevance(
|
||||||
np.array(embedding),
|
np.array(embedding), embeddings, k=k, lambda_mult=lambda_mult
|
||||||
embeddings,
|
|
||||||
k=k,
|
|
||||||
lambda_mult=lambda_mult,
|
|
||||||
)
|
)
|
||||||
return [
|
return [
|
||||||
(
|
(
|
||||||
self._document_from_scored_point_grpc(
|
self._document_from_scored_point(
|
||||||
response.result[i],
|
results[i], self.content_payload_key, self.metadata_payload_key
|
||||||
self.content_payload_key,
|
|
||||||
self.metadata_payload_key,
|
|
||||||
),
|
),
|
||||||
response.result[i].score,
|
results[i].score,
|
||||||
)
|
)
|
||||||
for i in mmr_selected
|
for i in mmr_selected
|
||||||
]
|
]
|
||||||
@@ -1543,7 +1535,7 @@ class Qdrant(VectorStore):
|
|||||||
**kwargs: Any,
|
**kwargs: Any,
|
||||||
) -> Qdrant:
|
) -> Qdrant:
|
||||||
try:
|
try:
|
||||||
import qdrant_client
|
import qdrant_client # noqa
|
||||||
except ImportError:
|
except ImportError:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
"Could not import qdrant-client python package. "
|
"Could not import qdrant-client python package. "
|
||||||
@@ -1558,7 +1550,7 @@ class Qdrant(VectorStore):
|
|||||||
vector_size = len(partial_embeddings[0])
|
vector_size = len(partial_embeddings[0])
|
||||||
collection_name = collection_name or uuid.uuid4().hex
|
collection_name = collection_name or uuid.uuid4().hex
|
||||||
distance_func = distance_func.upper()
|
distance_func = distance_func.upper()
|
||||||
client = qdrant_client.QdrantClient(
|
client, async_client = cls._generate_clients(
|
||||||
location=location,
|
location=location,
|
||||||
url=url,
|
url=url,
|
||||||
port=port,
|
port=port,
|
||||||
@@ -1669,6 +1661,7 @@ class Qdrant(VectorStore):
|
|||||||
metadata_payload_key=metadata_payload_key,
|
metadata_payload_key=metadata_payload_key,
|
||||||
distance_strategy=distance_func,
|
distance_strategy=distance_func,
|
||||||
vector_name=vector_name,
|
vector_name=vector_name,
|
||||||
|
async_client=async_client,
|
||||||
)
|
)
|
||||||
return qdrant
|
return qdrant
|
||||||
|
|
||||||
@@ -1707,7 +1700,7 @@ class Qdrant(VectorStore):
|
|||||||
**kwargs: Any,
|
**kwargs: Any,
|
||||||
) -> Qdrant:
|
) -> Qdrant:
|
||||||
try:
|
try:
|
||||||
import qdrant_client
|
import qdrant_client # noqa
|
||||||
except ImportError:
|
except ImportError:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
"Could not import qdrant-client python package. "
|
"Could not import qdrant-client python package. "
|
||||||
@@ -1722,7 +1715,7 @@ class Qdrant(VectorStore):
|
|||||||
vector_size = len(partial_embeddings[0])
|
vector_size = len(partial_embeddings[0])
|
||||||
collection_name = collection_name or uuid.uuid4().hex
|
collection_name = collection_name or uuid.uuid4().hex
|
||||||
distance_func = distance_func.upper()
|
distance_func = distance_func.upper()
|
||||||
client = qdrant_client.QdrantClient(
|
client, async_client = cls._generate_clients(
|
||||||
location=location,
|
location=location,
|
||||||
url=url,
|
url=url,
|
||||||
port=port,
|
port=port,
|
||||||
@@ -1833,6 +1826,7 @@ class Qdrant(VectorStore):
|
|||||||
metadata_payload_key=metadata_payload_key,
|
metadata_payload_key=metadata_payload_key,
|
||||||
distance_strategy=distance_func,
|
distance_strategy=distance_func,
|
||||||
vector_name=vector_name,
|
vector_name=vector_name,
|
||||||
|
async_client=async_client,
|
||||||
)
|
)
|
||||||
return qdrant
|
return qdrant
|
||||||
|
|
||||||
@@ -1922,21 +1916,6 @@ class Qdrant(VectorStore):
|
|||||||
metadata=scored_point.payload.get(metadata_payload_key) or {},
|
metadata=scored_point.payload.get(metadata_payload_key) or {},
|
||||||
)
|
)
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def _document_from_scored_point_grpc(
|
|
||||||
cls,
|
|
||||||
scored_point: Any,
|
|
||||||
content_payload_key: str,
|
|
||||||
metadata_payload_key: str,
|
|
||||||
) -> Document:
|
|
||||||
from qdrant_client.conversions.conversion import grpc_to_payload
|
|
||||||
|
|
||||||
payload = grpc_to_payload(scored_point.payload)
|
|
||||||
return Document(
|
|
||||||
page_content=payload[content_payload_key],
|
|
||||||
metadata=payload.get(metadata_payload_key) or {},
|
|
||||||
)
|
|
||||||
|
|
||||||
def _build_condition(self, key: str, value: Any) -> List[rest.FieldCondition]:
|
def _build_condition(self, key: str, value: Any) -> List[rest.FieldCondition]:
|
||||||
from qdrant_client.http import models as rest
|
from qdrant_client.http import models as rest
|
||||||
|
|
||||||
@@ -2134,3 +2113,57 @@ class Qdrant(VectorStore):
|
|||||||
]
|
]
|
||||||
|
|
||||||
yield batch_ids, points
|
yield batch_ids, points
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _generate_clients(
|
||||||
|
location: Optional[str] = None,
|
||||||
|
url: Optional[str] = None,
|
||||||
|
port: Optional[int] = 6333,
|
||||||
|
grpc_port: int = 6334,
|
||||||
|
prefer_grpc: bool = False,
|
||||||
|
https: Optional[bool] = None,
|
||||||
|
api_key: Optional[str] = None,
|
||||||
|
prefix: Optional[str] = None,
|
||||||
|
timeout: Optional[float] = None,
|
||||||
|
host: Optional[str] = None,
|
||||||
|
path: Optional[str] = None,
|
||||||
|
**kwargs: Any,
|
||||||
|
) -> Tuple[Any, Any]:
|
||||||
|
from qdrant_client import AsyncQdrantClient, QdrantClient
|
||||||
|
|
||||||
|
sync_client = QdrantClient(
|
||||||
|
location=location,
|
||||||
|
url=url,
|
||||||
|
port=port,
|
||||||
|
grpc_port=grpc_port,
|
||||||
|
prefer_grpc=prefer_grpc,
|
||||||
|
https=https,
|
||||||
|
api_key=api_key,
|
||||||
|
prefix=prefix,
|
||||||
|
timeout=timeout,
|
||||||
|
host=host,
|
||||||
|
path=path,
|
||||||
|
**kwargs,
|
||||||
|
)
|
||||||
|
|
||||||
|
if location == ":memory:" or path is not None:
|
||||||
|
# Local Qdrant cannot co-exist with Sync and Async clients
|
||||||
|
# We fall back to sync operations in this case
|
||||||
|
async_client = None
|
||||||
|
else:
|
||||||
|
async_client = AsyncQdrantClient(
|
||||||
|
location=location,
|
||||||
|
url=url,
|
||||||
|
port=port,
|
||||||
|
grpc_port=grpc_port,
|
||||||
|
prefer_grpc=prefer_grpc,
|
||||||
|
https=https,
|
||||||
|
api_key=api_key,
|
||||||
|
prefix=prefix,
|
||||||
|
timeout=timeout,
|
||||||
|
host=host,
|
||||||
|
path=path,
|
||||||
|
**kwargs,
|
||||||
|
)
|
||||||
|
|
||||||
|
return sync_client, async_client
|
||||||
|
Reference in New Issue
Block a user