diff --git a/libs/langchain/langchain/vectorstores/qdrant.py b/libs/langchain/langchain/vectorstores/qdrant.py
index 44d8f9a6cd3..f18b2cc9124 100644
--- a/libs/langchain/langchain/vectorstores/qdrant.py
+++ b/libs/langchain/langchain/vectorstores/qdrant.py
@@ -10,6 +10,7 @@ from operator import itemgetter
 from typing import (
     TYPE_CHECKING,
     Any,
+    AsyncGenerator,
     Callable,
     Dict,
     Generator,
@@ -213,7 +214,7 @@ class Qdrant(VectorStore):
         from qdrant_client.conversions.conversion import RestToGrpc
 
         added_ids = []
-        for batch_ids, points in self._generate_rest_batches(
+        async for batch_ids, points in self._agenerate_rest_batches(
             texts, metadatas, ids, batch_size
         ):
             await self.client.async_grpc_points.Upsert(
@@ -1264,7 +1265,7 @@ class Qdrant(VectorStore):
                 embeddings = OpenAIEmbeddings()
                 qdrant = await Qdrant.afrom_texts(texts, embeddings, "localhost")
         """
-        qdrant = cls._construct_instance(
+        qdrant = await cls._aconstruct_instance(
             texts,
             embedding,
             location,
@@ -1465,6 +1466,172 @@ class Qdrant(VectorStore):
         )
         return qdrant
 
+    @classmethod
+    async def _aconstruct_instance(
+        cls: Type[Qdrant],
+        texts: List[str],
+        embedding: Embeddings,
+        location: Optional[str] = None,
+        url: Optional[str] = None,
+        port: Optional[int] = 6333,
+        grpc_port: int = 6334,
+        prefer_grpc: bool = False,
+        https: Optional[bool] = None,
+        api_key: Optional[str] = None,
+        prefix: Optional[str] = None,
+        timeout: Optional[float] = None,
+        host: Optional[str] = None,
+        path: Optional[str] = None,
+        collection_name: Optional[str] = None,
+        distance_func: str = "Cosine",
+        content_payload_key: str = CONTENT_KEY,
+        metadata_payload_key: str = METADATA_KEY,
+        vector_name: Optional[str] = VECTOR_NAME,
+        shard_number: Optional[int] = None,
+        replication_factor: Optional[int] = None,
+        write_consistency_factor: Optional[int] = None,
+        on_disk_payload: Optional[bool] = None,
+        hnsw_config: Optional[common_types.HnswConfigDiff] = None,
+        optimizers_config: Optional[common_types.OptimizersConfigDiff] = None,
+        wal_config: Optional[common_types.WalConfigDiff] = None,
+        quantization_config: Optional[common_types.QuantizationConfig] = None,
+        init_from: Optional[common_types.InitFrom] = None,
+        on_disk: Optional[bool] = None,
+        force_recreate: bool = False,
+        **kwargs: Any,
+    ) -> Qdrant:
+        try:
+            import qdrant_client
+        except ImportError:
+            raise ValueError(
+                "Could not import qdrant-client python package. "
+                "Please install it with `pip install qdrant-client`."
+            )
+        from grpc import RpcError
+        from qdrant_client.http import models as rest
+        from qdrant_client.http.exceptions import UnexpectedResponse
+
+        # Just do a single quick embedding to get vector size
+        partial_embeddings = await embedding.aembed_documents(texts[:1])
+        vector_size = len(partial_embeddings[0])
+        collection_name = collection_name or uuid.uuid4().hex
+        distance_func = distance_func.upper()
+        client = qdrant_client.QdrantClient(
+            location=location,
+            url=url,
+            port=port,
+            grpc_port=grpc_port,
+            prefer_grpc=prefer_grpc,
+            https=https,
+            api_key=api_key,
+            prefix=prefix,
+            timeout=timeout,
+            host=host,
+            path=path,
+            **kwargs,
+        )
+        try:
+            # Skip any validation in case of forced collection recreate.
+            if force_recreate:
+                raise ValueError
+
+            # Get the vector configuration of the existing collection and vector, if it
+            # was specified. If the old configuration does not match the current one,
+            # an exception is being thrown.
+            collection_info = client.get_collection(collection_name=collection_name)
+            current_vector_config = collection_info.config.params.vectors
+            if isinstance(current_vector_config, dict) and vector_name is not None:
+                if vector_name not in current_vector_config:
+                    raise QdrantException(
+                        f"Existing Qdrant collection {collection_name} does not "
+                        f"contain vector named {vector_name}. Did you mean one of the "
+                        f"existing vectors: {', '.join(current_vector_config.keys())}? "
+                        f"If you want to recreate the collection, set `force_recreate` "
+                        f"parameter to `True`."
+                    )
+                current_vector_config = current_vector_config.get(
+                    vector_name
+                )  # type: ignore[assignment]
+            elif isinstance(current_vector_config, dict) and vector_name is None:
+                raise QdrantException(
+                    f"Existing Qdrant collection {collection_name} uses named vectors. "
+                    f"If you want to reuse it, please set `vector_name` to any of the "
+                    f"existing named vectors: "
+                    f"{', '.join(current_vector_config.keys())}."  # noqa
+                    f"If you want to recreate the collection, set `force_recreate` "
+                    f"parameter to `True`."
+                )
+            elif (
+                not isinstance(current_vector_config, dict) and vector_name is not None
+            ):
+                raise QdrantException(
+                    f"Existing Qdrant collection {collection_name} doesn't use named "
+                    f"vectors. If you want to reuse it, please set `vector_name` to "
+                    f"`None`. If you want to recreate the collection, set "
+                    f"`force_recreate` parameter to `True`."
+                )
+
+            # Check if the vector configuration has the same dimensionality.
+            if current_vector_config.size != vector_size:  # type: ignore[union-attr]
+                raise QdrantException(
+                    f"Existing Qdrant collection is configured for vectors with "
+                    f"{current_vector_config.size} "  # type: ignore[union-attr]
+                    f"dimensions. Selected embeddings are {vector_size}-dimensional. "
+                    f"If you want to recreate the collection, set `force_recreate` "
+                    f"parameter to `True`."
+                )
+
+            current_distance_func = (
+                current_vector_config.distance.name.upper()  # type: ignore[union-attr]
+            )
+            if current_distance_func != distance_func:
+                raise QdrantException(
+                    f"Existing Qdrant collection is configured for "
+                    f"{current_vector_config.distance} "  # type: ignore[union-attr]
+                    f"similarity. Please set `distance_func` parameter to "
+                    f"`{distance_func}` if you want to reuse it. If you want to "
+                    f"recreate the collection, set `force_recreate` parameter to "
+                    f"`True`."
+                )
+        except (UnexpectedResponse, RpcError, ValueError):
+            vectors_config = rest.VectorParams(
+                size=vector_size,
+                distance=rest.Distance[distance_func],
+                on_disk=on_disk,
+            )
+
+            # If vector name was provided, we're going to use the named vectors feature
+            # with just a single vector.
+            if vector_name is not None:
+                vectors_config = {  # type: ignore[assignment]
+                    vector_name: vectors_config,
+                }
+
+            client.recreate_collection(
+                collection_name=collection_name,
+                vectors_config=vectors_config,
+                shard_number=shard_number,
+                replication_factor=replication_factor,
+                write_consistency_factor=write_consistency_factor,
+                on_disk_payload=on_disk_payload,
+                hnsw_config=hnsw_config,
+                optimizers_config=optimizers_config,
+                wal_config=wal_config,
+                quantization_config=quantization_config,
+                init_from=init_from,
+                timeout=timeout,  # type: ignore[arg-type]
+            )
+        qdrant = cls(
+            client=client,
+            collection_name=collection_name,
+            embeddings=embedding,
+            content_payload_key=content_payload_key,
+            metadata_payload_key=metadata_payload_key,
+            distance_strategy=distance_func,
+            vector_name=vector_name,
+        )
+        return qdrant
+
     def _select_relevance_score_fn(self) -> Callable[[float], float]:
         """
         The 'correct' relevance function
@@ -1648,6 +1815,33 @@ class Qdrant(VectorStore):
 
         return embeddings
 
+    async def _aembed_texts(self, texts: Iterable[str]) -> List[List[float]]:
+        """Embed search texts.
+
+        Used to provide backward compatibility with `embedding_function` argument.
+
+        Args:
+            texts: Iterable of texts to embed.
+
+        Returns:
+            List of floats representing the texts embedding.
+        """
+        if self.embeddings is not None:
+            embeddings = await self.embeddings.aembed_documents(list(texts))
+            if hasattr(embeddings, "tolist"):
+                embeddings = embeddings.tolist()
+        elif self._embeddings_function is not None:
+            embeddings = []
+            for text in texts:
+                embedding = self._embeddings_function(text)
+                if hasattr(embeddings, "tolist"):
+                    embedding = embedding.tolist()
+                embeddings.append(embedding)
+        else:
+            raise ValueError("Neither of embeddings or embedding_function is set")
+
+        return embeddings
+
     def _generate_rest_batches(
         self,
         texts: Iterable[str],
@@ -1689,3 +1883,45 @@ class Qdrant(VectorStore):
             ]
 
             yield batch_ids, points
+
+    async def _agenerate_rest_batches(
+        self,
+        texts: Iterable[str],
+        metadatas: Optional[List[dict]] = None,
+        ids: Optional[Sequence[str]] = None,
+        batch_size: int = 64,
+    ) -> AsyncGenerator[Tuple[List[str], List[rest.PointStruct]], None]:
+        from qdrant_client.http import models as rest
+
+        texts_iterator = iter(texts)
+        metadatas_iterator = iter(metadatas or [])
+        ids_iterator = iter(ids or [uuid.uuid4().hex for _ in iter(texts)])
+        while batch_texts := list(islice(texts_iterator, batch_size)):
+            # Take the corresponding metadata and id for each text in a batch
+            batch_metadatas = list(islice(metadatas_iterator, batch_size)) or None
+            batch_ids = list(islice(ids_iterator, batch_size))
+
+            # Generate the embeddings for all the texts in a batch
+            batch_embeddings = await self._aembed_texts(batch_texts)
+
+            points = [
+                rest.PointStruct(
+                    id=point_id,
+                    vector=vector
+                    if self.vector_name is None
+                    else {self.vector_name: vector},
+                    payload=payload,
+                )
+                for point_id, vector, payload in zip(
+                    batch_ids,
+                    batch_embeddings,
+                    self._build_payloads(
+                        batch_texts,
+                        batch_metadatas,
+                        self.content_payload_key,
+                        self.metadata_payload_key,
+                    ),
+                )
+            ]
+
+            yield batch_ids, points
diff --git a/libs/langchain/tests/integration_tests/vectorstores/fake_embeddings.py b/libs/langchain/tests/integration_tests/vectorstores/fake_embeddings.py
index 1b299bba18d..550174e2e52 100644
--- a/libs/langchain/tests/integration_tests/vectorstores/fake_embeddings.py
+++ b/libs/langchain/tests/integration_tests/vectorstores/fake_embeddings.py
@@ -15,6 +15,9 @@ class FakeEmbeddings(Embeddings):
         Embeddings encode each text as its index."""
         return [[float(1.0)] * 9 + [float(i)] for i in range(len(texts))]
 
+    async def aembed_documents(self, texts: List[str]) -> List[List[float]]:
+        return self.embed_documents(texts)
+
     def embed_query(self, text: str) -> List[float]:
         """Return constant query embeddings.
         Embeddings are identical to embed_documents(texts)[0].
@@ -22,6 +25,9 @@ class FakeEmbeddings(Embeddings):
         as it was passed to embed_documents."""
         return [float(1.0)] * 9 + [float(0.0)]
 
+    async def aembed_query(self, text: str) -> List[float]:
+        return self.embed_query(text)
+
 
 class ConsistentFakeEmbeddings(FakeEmbeddings):
     """Fake embeddings which remember all the texts seen so far to return consistent
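Below is a minimal usage sketch of the async path this diff wires up. It is an assumption-laden example, not part of the change itself: it assumes a Qdrant server reachable on localhost with gRPC enabled (the async upsert goes through `client.async_grpc_points`), OpenAI credentials in the environment, and placeholder texts and collection name; the keyword arguments mirror the sync `from_texts` signature.

```python
import asyncio

from langchain.embeddings import OpenAIEmbeddings
from langchain.vectorstores import Qdrant


async def main() -> None:
    embeddings = OpenAIEmbeddings()

    # afrom_texts now goes through the async _aconstruct_instance above, so the
    # single "probe" embedding used to detect the vector size is awaited rather
    # than computed synchronously.
    qdrant = await Qdrant.afrom_texts(
        ["Lorem ipsum dolor sit amet", "consectetur adipiscing elit"],
        embeddings,
        host="localhost",
        prefer_grpc=True,
        collection_name="async_demo",  # hypothetical collection name
    )

    # Later inserts batch through _agenerate_rest_batches / _aembed_texts, which
    # await aembed_documents for each batch before upserting over gRPC.
    await qdrant.aadd_texts(["sed do eiusmod tempor incididunt"])


asyncio.run(main())
```

The `aembed_documents` / `aembed_query` methods added to the `FakeEmbeddings` test helper simply delegate to their sync counterparts, so the integration tests can exercise the same async path without a real embedding provider.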