mirror of https://github.com/hwchase17/langchain.git
synced 2025-08-08 04:25:46 +00:00

Enhance qdrant vs using async embed documents (#9462)

This is an extension of #8104. I updated some of the signatures so all the tests pass. @danhnn I couldn't commit to your PR, so I created a new one. Thanks for your contribution! @baskaryan Could you please merge it?

Co-authored-by: Danh Nguyen <dnncntt@gmail.com>

This commit is contained in:
parent 83d2a871eb
commit 616e728ef9
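
The diff below wires `aembed_documents` into the Qdrant vector store's async code paths, so `afrom_texts` and `aadd_texts` no longer block on the synchronous embedder. For context, a minimal usage sketch (illustrative only, not part of this commit; it assumes a Qdrant server reachable on localhost with REST port 6333 and gRPC port 6334, plus a configured OpenAI API key):

import asyncio

from langchain.embeddings import OpenAIEmbeddings
from langchain.vectorstores import Qdrant


async def main() -> None:
    embeddings = OpenAIEmbeddings()  # assumption: OPENAI_API_KEY is set
    # afrom_texts probes the embedding size with aembed_documents and
    # creates (or validates) the collection before inserting the texts.
    qdrant = await Qdrant.afrom_texts(
        ["hello qdrant", "hello langchain"],
        embeddings,
        url="http://localhost:6333",
        prefer_grpc=True,  # aadd_texts upserts through the async gRPC client
        collection_name="async_demo",
    )
    # aadd_texts embeds each batch asynchronously before upserting.
    await qdrant.aadd_texts(["one more document"])


asyncio.run(main())
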
@@ -10,6 +10,7 @@ from operator import itemgetter
 from typing import (
     TYPE_CHECKING,
     Any,
+    AsyncGenerator,
     Callable,
     Dict,
     Generator,
@@ -213,7 +214,7 @@ class Qdrant(VectorStore):
         from qdrant_client.conversions.conversion import RestToGrpc
 
         added_ids = []
-        for batch_ids, points in self._generate_rest_batches(
+        async for batch_ids, points in self._agenerate_rest_batches(
             texts, metadatas, ids, batch_size
         ):
             await self.client.async_grpc_points.Upsert(
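
The only change in `aadd_texts` above is that batches now come from an async generator, which must be consumed with `async for`. A generic, self-contained sketch of that pattern (hypothetical names, not taken from the codebase):

import asyncio
from typing import AsyncGenerator, List


async def batches(items: List[str], size: int) -> AsyncGenerator[List[str], None]:
    # An async generator may await (e.g. an embedding request) between yields.
    for start in range(0, len(items), size):
        await asyncio.sleep(0)  # stand-in for an awaited embedding call
        yield items[start : start + size]


async def consume() -> None:
    # `async for` is required; a plain `for` over an async generator raises TypeError.
    async for batch in batches(["a", "b", "c"], size=2):
        print(batch)


asyncio.run(consume())
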
@@ -1264,7 +1265,7 @@ class Qdrant(VectorStore):
                 embeddings = OpenAIEmbeddings()
                 qdrant = await Qdrant.afrom_texts(texts, embeddings, "localhost")
         """
-        qdrant = cls._construct_instance(
+        qdrant = await cls._aconstruct_instance(
             texts,
             embedding,
             location,
@@ -1465,6 +1466,172 @@ class Qdrant(VectorStore):
         )
         return qdrant
 
+    @classmethod
+    async def _aconstruct_instance(
+        cls: Type[Qdrant],
+        texts: List[str],
+        embedding: Embeddings,
+        location: Optional[str] = None,
+        url: Optional[str] = None,
+        port: Optional[int] = 6333,
+        grpc_port: int = 6334,
+        prefer_grpc: bool = False,
+        https: Optional[bool] = None,
+        api_key: Optional[str] = None,
+        prefix: Optional[str] = None,
+        timeout: Optional[float] = None,
+        host: Optional[str] = None,
+        path: Optional[str] = None,
+        collection_name: Optional[str] = None,
+        distance_func: str = "Cosine",
+        content_payload_key: str = CONTENT_KEY,
+        metadata_payload_key: str = METADATA_KEY,
+        vector_name: Optional[str] = VECTOR_NAME,
+        shard_number: Optional[int] = None,
+        replication_factor: Optional[int] = None,
+        write_consistency_factor: Optional[int] = None,
+        on_disk_payload: Optional[bool] = None,
+        hnsw_config: Optional[common_types.HnswConfigDiff] = None,
+        optimizers_config: Optional[common_types.OptimizersConfigDiff] = None,
+        wal_config: Optional[common_types.WalConfigDiff] = None,
+        quantization_config: Optional[common_types.QuantizationConfig] = None,
+        init_from: Optional[common_types.InitFrom] = None,
+        on_disk: Optional[bool] = None,
+        force_recreate: bool = False,
+        **kwargs: Any,
+    ) -> Qdrant:
+        try:
+            import qdrant_client
+        except ImportError:
+            raise ValueError(
+                "Could not import qdrant-client python package. "
+                "Please install it with `pip install qdrant-client`."
+            )
+        from grpc import RpcError
+        from qdrant_client.http import models as rest
+        from qdrant_client.http.exceptions import UnexpectedResponse
+
+        # Just do a single quick embedding to get vector size
+        partial_embeddings = await embedding.aembed_documents(texts[:1])
+        vector_size = len(partial_embeddings[0])
+        collection_name = collection_name or uuid.uuid4().hex
+        distance_func = distance_func.upper()
+        client = qdrant_client.QdrantClient(
+            location=location,
+            url=url,
+            port=port,
+            grpc_port=grpc_port,
+            prefer_grpc=prefer_grpc,
+            https=https,
+            api_key=api_key,
+            prefix=prefix,
+            timeout=timeout,
+            host=host,
+            path=path,
+            **kwargs,
+        )
+        try:
+            # Skip any validation in case of forced collection recreate.
+            if force_recreate:
+                raise ValueError
+
+            # Get the vector configuration of the existing collection and vector, if it
+            # was specified. If the old configuration does not match the current one,
+            # an exception is being thrown.
+            collection_info = client.get_collection(collection_name=collection_name)
+            current_vector_config = collection_info.config.params.vectors
+            if isinstance(current_vector_config, dict) and vector_name is not None:
+                if vector_name not in current_vector_config:
+                    raise QdrantException(
+                        f"Existing Qdrant collection {collection_name} does not "
+                        f"contain vector named {vector_name}. Did you mean one of the "
+                        f"existing vectors: {', '.join(current_vector_config.keys())}? "
+                        f"If you want to recreate the collection, set `force_recreate` "
+                        f"parameter to `True`."
+                    )
+                current_vector_config = current_vector_config.get(
+                    vector_name
+                )  # type: ignore[assignment]
+            elif isinstance(current_vector_config, dict) and vector_name is None:
+                raise QdrantException(
+                    f"Existing Qdrant collection {collection_name} uses named vectors. "
+                    f"If you want to reuse it, please set `vector_name` to any of the "
+                    f"existing named vectors: "
+                    f"{', '.join(current_vector_config.keys())}."  # noqa
+                    f"If you want to recreate the collection, set `force_recreate` "
+                    f"parameter to `True`."
+                )
+            elif (
+                not isinstance(current_vector_config, dict) and vector_name is not None
+            ):
+                raise QdrantException(
+                    f"Existing Qdrant collection {collection_name} doesn't use named "
+                    f"vectors. If you want to reuse it, please set `vector_name` to "
+                    f"`None`. If you want to recreate the collection, set "
+                    f"`force_recreate` parameter to `True`."
+                )
+
+            # Check if the vector configuration has the same dimensionality.
+            if current_vector_config.size != vector_size:  # type: ignore[union-attr]
+                raise QdrantException(
+                    f"Existing Qdrant collection is configured for vectors with "
+                    f"{current_vector_config.size} "  # type: ignore[union-attr]
+                    f"dimensions. Selected embeddings are {vector_size}-dimensional. "
+                    f"If you want to recreate the collection, set `force_recreate` "
+                    f"parameter to `True`."
+                )
+
+            current_distance_func = (
+                current_vector_config.distance.name.upper()  # type: ignore[union-attr]
+            )
+            if current_distance_func != distance_func:
+                raise QdrantException(
+                    f"Existing Qdrant collection is configured for "
+                    f"{current_vector_config.distance} "  # type: ignore[union-attr]
+                    f"similarity. Please set `distance_func` parameter to "
+                    f"`{distance_func}` if you want to reuse it. If you want to "
+                    f"recreate the collection, set `force_recreate` parameter to "
+                    f"`True`."
+                )
+        except (UnexpectedResponse, RpcError, ValueError):
+            vectors_config = rest.VectorParams(
+                size=vector_size,
+                distance=rest.Distance[distance_func],
+                on_disk=on_disk,
+            )
+
+            # If vector name was provided, we're going to use the named vectors feature
+            # with just a single vector.
+            if vector_name is not None:
+                vectors_config = {  # type: ignore[assignment]
+                    vector_name: vectors_config,
+                }
+
+            client.recreate_collection(
+                collection_name=collection_name,
+                vectors_config=vectors_config,
+                shard_number=shard_number,
+                replication_factor=replication_factor,
+                write_consistency_factor=write_consistency_factor,
+                on_disk_payload=on_disk_payload,
+                hnsw_config=hnsw_config,
+                optimizers_config=optimizers_config,
+                wal_config=wal_config,
+                quantization_config=quantization_config,
+                init_from=init_from,
+                timeout=timeout,  # type: ignore[arg-type]
+            )
+        qdrant = cls(
+            client=client,
+            collection_name=collection_name,
+            embeddings=embedding,
+            content_payload_key=content_payload_key,
+            metadata_payload_key=metadata_payload_key,
+            distance_strategy=distance_func,
+            vector_name=vector_name,
+        )
+        return qdrant
+
     def _select_relevance_score_fn(self) -> Callable[[float], float]:
         """
         The 'correct' relevance function
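
For callers of `afrom_texts`, the validation in `_aconstruct_instance` means an existing collection is reused only if its vector size, distance function, and (optional) vector name match; otherwise a `QdrantException` is raised unless `force_recreate=True`. A hedged sketch of handling that (illustrative only; assumes a Qdrant server on localhost and an existing collection named "docs" created with a different embedding dimensionality):

from langchain.vectorstores import Qdrant
from langchain.vectorstores.qdrant import QdrantException


async def reuse_or_recreate(texts, embeddings_384d):
    # Hypothetical helper: embeddings_384d produces 384-dimensional vectors,
    # while the existing "docs" collection was created with 1536 dimensions.
    try:
        # Fails with QdrantException because the dimensionalities differ.
        return await Qdrant.afrom_texts(
            texts, embeddings_384d, url="http://localhost:6333", collection_name="docs"
        )
    except QdrantException:
        # Dropping and recreating the collection is an explicit opt-in.
        return await Qdrant.afrom_texts(
            texts,
            embeddings_384d,
            url="http://localhost:6333",
            collection_name="docs",
            force_recreate=True,
        )
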
@@ -1648,6 +1815,33 @@ class Qdrant(VectorStore):
 
         return embeddings
 
+    async def _aembed_texts(self, texts: Iterable[str]) -> List[List[float]]:
+        """Embed search texts.
+
+        Used to provide backward compatibility with `embedding_function` argument.
+
+        Args:
+            texts: Iterable of texts to embed.
+
+        Returns:
+            List of floats representing the texts embedding.
+        """
+        if self.embeddings is not None:
+            embeddings = await self.embeddings.aembed_documents(list(texts))
+            if hasattr(embeddings, "tolist"):
+                embeddings = embeddings.tolist()
+        elif self._embeddings_function is not None:
+            embeddings = []
+            for text in texts:
+                embedding = self._embeddings_function(text)
+                if hasattr(embedding, "tolist"):
+                    embedding = embedding.tolist()
+                embeddings.append(embedding)
+        else:
+            raise ValueError("Neither of embeddings or embedding_function is set")
+
+        return embeddings
+
     def _generate_rest_batches(
         self,
         texts: Iterable[str],
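
`_aembed_texts` prefers `self.embeddings.aembed_documents` and falls back to calling the legacy synchronous `embedding_function` per text; the `tolist()` checks normalize embedders that return numpy arrays into plain lists. A small sketch of that normalization (assumes numpy is installed; the embedder here is hypothetical, not the real helper):

import numpy as np


def fake_embedding_function(text: str) -> np.ndarray:
    # Hypothetical legacy embedder that returns a numpy array.
    return np.ones(4, dtype=np.float32)


vector = fake_embedding_function("hello")
if hasattr(vector, "tolist"):
    vector = vector.tolist()  # Qdrant point payloads need JSON-serializable lists
print(vector)  # [1.0, 1.0, 1.0, 1.0]
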
@@ -1689,3 +1883,45 @@ class Qdrant(VectorStore):
             ]
 
             yield batch_ids, points
+
+    async def _agenerate_rest_batches(
+        self,
+        texts: Iterable[str],
+        metadatas: Optional[List[dict]] = None,
+        ids: Optional[Sequence[str]] = None,
+        batch_size: int = 64,
+    ) -> AsyncGenerator[Tuple[List[str], List[rest.PointStruct]], None]:
+        from qdrant_client.http import models as rest
+
+        texts_iterator = iter(texts)
+        metadatas_iterator = iter(metadatas or [])
+        ids_iterator = iter(ids or [uuid.uuid4().hex for _ in iter(texts)])
+        while batch_texts := list(islice(texts_iterator, batch_size)):
+            # Take the corresponding metadata and id for each text in a batch
+            batch_metadatas = list(islice(metadatas_iterator, batch_size)) or None
+            batch_ids = list(islice(ids_iterator, batch_size))
+
+            # Generate the embeddings for all the texts in a batch
+            batch_embeddings = await self._aembed_texts(batch_texts)
+
+            points = [
+                rest.PointStruct(
+                    id=point_id,
+                    vector=vector
+                    if self.vector_name is None
+                    else {self.vector_name: vector},
+                    payload=payload,
+                )
+                for point_id, vector, payload in zip(
+                    batch_ids,
+                    batch_embeddings,
+                    self._build_payloads(
+                        batch_texts,
+                        batch_metadatas,
+                        self.content_payload_key,
+                        self.metadata_payload_key,
+                    ),
+                )
+            ]
+
+            yield batch_ids, points
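
`_agenerate_rest_batches` mirrors its synchronous counterpart: `islice` plus the walrus operator drains the text/metadata/id iterators in `batch_size` chunks, and one awaited embedding call is made per chunk. A standalone sketch of the batching pattern (hypothetical embedder, not the real `_aembed_texts`):

import asyncio
from itertools import islice
from typing import AsyncGenerator, List, Tuple


async def fake_aembed(batch: List[str]) -> List[List[float]]:
    # Stand-in for one awaited embedding request per batch.
    return [[float(len(text))] for text in batch]


async def agenerate_batches(
    texts: List[str], batch_size: int = 2
) -> AsyncGenerator[Tuple[List[str], List[List[float]]], None]:
    texts_iterator = iter(texts)
    # The walrus loop stops as soon as islice returns an empty batch.
    while batch_texts := list(islice(texts_iterator, batch_size)):
        yield batch_texts, await fake_aembed(batch_texts)


async def main() -> None:
    async for batch, vectors in agenerate_batches(["a", "bb", "ccc"]):
        print(batch, vectors)


asyncio.run(main())
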
@@ -15,6 +15,9 @@ class FakeEmbeddings(Embeddings):
         Embeddings encode each text as its index."""
         return [[float(1.0)] * 9 + [float(i)] for i in range(len(texts))]
 
+    async def aembed_documents(self, texts: List[str]) -> List[List[float]]:
+        return self.embed_documents(texts)
+
     def embed_query(self, text: str) -> List[float]:
         """Return constant query embeddings.
         Embeddings are identical to embed_documents(texts)[0].
@@ -22,6 +25,9 @@ class FakeEmbeddings(Embeddings):
         as it was passed to embed_documents."""
         return [float(1.0)] * 9 + [float(0.0)]
 
+    async def aembed_query(self, text: str) -> List[float]:
+        return self.embed_query(text)
+
 
 class ConsistentFakeEmbeddings(FakeEmbeddings):
     """Fake embeddings which remember all the texts seen so far to return consistent
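
With the async methods added to the fake embeddings, the new code paths can be exercised in integration tests. A hypothetical test sketch, not part of this diff (assumes pytest-asyncio and a Qdrant server on localhost:6333; the real tests in the PR may be structured differently):

from typing import List

import pytest

from langchain.embeddings.base import Embeddings
from langchain.vectorstores import Qdrant


class TinyFakeEmbeddings(Embeddings):
    """Inline stand-in mirroring the FakeEmbeddings patched above."""

    def embed_documents(self, texts: List[str]) -> List[List[float]]:
        return [[1.0] * 9 + [float(i)] for i in range(len(texts))]

    def embed_query(self, text: str) -> List[float]:
        return [1.0] * 9 + [0.0]

    async def aembed_documents(self, texts: List[str]) -> List[List[float]]:
        return self.embed_documents(texts)

    async def aembed_query(self, text: str) -> List[float]:
        return self.embed_query(text)


@pytest.mark.asyncio
async def test_afrom_texts_uses_async_embeddings() -> None:
    docsearch = await Qdrant.afrom_texts(
        ["foo", "bar", "baz"],
        TinyFakeEmbeddings(),
        url="http://localhost:6333",
        collection_name="test_async_embed",
    )
    output = await docsearch.asimilarity_search("foo", k=1)
    assert len(output) == 1
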