mirror of
https://github.com/hwchase17/langchain.git
synced 2025-08-10 13:27:36 +00:00
Enhance Qdrant vector store to use async embed_documents (#9462)
This is an extension of #8104. I updated some of the signatures so all the tests pass. @danhnn I couldn't commit to your PR, so I created a new one. Thanks for your contribution! @baskaryan Could you please merge it? --------- Co-authored-by: Danh Nguyen <dnncntt@gmail.com>
This commit is contained in:
parent
83d2a871eb
commit
616e728ef9
@ -10,6 +10,7 @@ from operator import itemgetter
|
|||||||
from typing import (
|
from typing import (
|
||||||
TYPE_CHECKING,
|
TYPE_CHECKING,
|
||||||
Any,
|
Any,
|
||||||
|
AsyncGenerator,
|
||||||
Callable,
|
Callable,
|
||||||
Dict,
|
Dict,
|
||||||
Generator,
|
Generator,
|
||||||
@ -213,7 +214,7 @@ class Qdrant(VectorStore):
|
|||||||
from qdrant_client.conversions.conversion import RestToGrpc
|
from qdrant_client.conversions.conversion import RestToGrpc
|
||||||
|
|
||||||
added_ids = []
|
added_ids = []
|
||||||
for batch_ids, points in self._generate_rest_batches(
|
async for batch_ids, points in self._agenerate_rest_batches(
|
||||||
texts, metadatas, ids, batch_size
|
texts, metadatas, ids, batch_size
|
||||||
):
|
):
|
||||||
await self.client.async_grpc_points.Upsert(
|
await self.client.async_grpc_points.Upsert(
|
||||||
@ -1264,7 +1265,7 @@ class Qdrant(VectorStore):
|
|||||||
embeddings = OpenAIEmbeddings()
|
embeddings = OpenAIEmbeddings()
|
||||||
qdrant = await Qdrant.afrom_texts(texts, embeddings, "localhost")
|
qdrant = await Qdrant.afrom_texts(texts, embeddings, "localhost")
|
||||||
"""
|
"""
|
||||||
qdrant = cls._construct_instance(
|
qdrant = await cls._aconstruct_instance(
|
||||||
texts,
|
texts,
|
||||||
embedding,
|
embedding,
|
||||||
location,
|
location,
|
||||||
@ -1465,6 +1466,172 @@ class Qdrant(VectorStore):
|
|||||||
)
|
)
|
||||||
return qdrant
|
return qdrant
|
||||||
|
|
||||||
|
@classmethod
async def _aconstruct_instance(
    cls: Type[Qdrant],
    texts: List[str],
    embedding: Embeddings,
    location: Optional[str] = None,
    url: Optional[str] = None,
    port: Optional[int] = 6333,
    grpc_port: int = 6334,
    prefer_grpc: bool = False,
    https: Optional[bool] = None,
    api_key: Optional[str] = None,
    prefix: Optional[str] = None,
    timeout: Optional[float] = None,
    host: Optional[str] = None,
    path: Optional[str] = None,
    collection_name: Optional[str] = None,
    distance_func: str = "Cosine",
    content_payload_key: str = CONTENT_KEY,
    metadata_payload_key: str = METADATA_KEY,
    vector_name: Optional[str] = VECTOR_NAME,
    shard_number: Optional[int] = None,
    replication_factor: Optional[int] = None,
    write_consistency_factor: Optional[int] = None,
    on_disk_payload: Optional[bool] = None,
    hnsw_config: Optional[common_types.HnswConfigDiff] = None,
    optimizers_config: Optional[common_types.OptimizersConfigDiff] = None,
    wal_config: Optional[common_types.WalConfigDiff] = None,
    quantization_config: Optional[common_types.QuantizationConfig] = None,
    init_from: Optional[common_types.InitFrom] = None,
    on_disk: Optional[bool] = None,
    force_recreate: bool = False,
    **kwargs: Any,
) -> Qdrant:
    """Create a client and a matching collection, then wrap them in `cls`.

    The vector size is discovered by asynchronously embedding the first text.
    If a collection named `collection_name` can be fetched and
    `force_recreate` is False, its vector configuration (named vs. unnamed
    vectors, dimensionality, distance function) is validated against the
    requested one. Otherwise the collection is (re)created.

    Args:
        texts: Texts to index; only the first one is embedded here, to
            determine the vector size.
        embedding: Embeddings implementation providing `aembed_documents`.
        location, url, port, grpc_port, prefer_grpc, https, api_key, prefix,
            timeout, host, path, **kwargs: Forwarded to
            `qdrant_client.QdrantClient`.
        collection_name: Name of the collection; a random hex UUID is used
            when omitted.
        distance_func: Name of the distance function; upper-cased and looked
            up in `rest.Distance`.
        content_payload_key, metadata_payload_key, vector_name: Stored on the
            returned instance. A non-None `vector_name` makes the collection
            use a single named vector.
        shard_number, replication_factor, write_consistency_factor,
            on_disk_payload, hnsw_config, optimizers_config, wal_config,
            quantization_config, init_from: Forwarded to
            `client.recreate_collection`.
        on_disk: Forwarded to `rest.VectorParams`.
        force_recreate: Skip validation and always recreate the collection.

    Returns:
        A `cls` instance bound to the created client and collection.

    Raises:
        ValueError: If the `qdrant-client` package cannot be imported.
        QdrantException: If the existing collection is incompatible with the
            requested configuration (assumes `QdrantException`, defined
            outside this block, is not a subclass of the exceptions caught
            below — TODO confirm).
    """
    try:
        import qdrant_client
    except ImportError as import_error:
        # Keep ValueError for backward compatibility, but preserve the cause.
        raise ValueError(
            "Could not import qdrant-client python package. "
            "Please install it with `pip install qdrant-client`."
        ) from import_error
    from grpc import RpcError
    from qdrant_client.http import models as rest
    from qdrant_client.http.exceptions import UnexpectedResponse

    # Just do a single quick embedding to get vector size
    partial_embeddings = await embedding.aembed_documents(texts[:1])
    vector_size = len(partial_embeddings[0])
    collection_name = collection_name or uuid.uuid4().hex
    distance_func = distance_func.upper()
    client = qdrant_client.QdrantClient(
        location=location,
        url=url,
        port=port,
        grpc_port=grpc_port,
        prefer_grpc=prefer_grpc,
        https=https,
        api_key=api_key,
        prefix=prefix,
        timeout=timeout,
        host=host,
        path=path,
        **kwargs,
    )
    # NOTE(review): `get_collection` / `recreate_collection` below are plain
    # (un-awaited) client calls, so they run synchronously inside this
    # coroutine; only the embedding step is awaited.
    try:
        # Skip any validation in case of forced collection recreate.
        # The bare ValueError is control flow: it jumps to the handler below,
        # which recreates the collection.
        if force_recreate:
            raise ValueError

        # Get the vector configuration of the existing collection and vector, if it
        # was specified. If the old configuration does not match the current one,
        # an exception is being thrown.
        collection_info = client.get_collection(collection_name=collection_name)
        current_vector_config = collection_info.config.params.vectors
        if isinstance(current_vector_config, dict) and vector_name is not None:
            if vector_name not in current_vector_config:
                raise QdrantException(
                    f"Existing Qdrant collection {collection_name} does not "
                    f"contain vector named {vector_name}. Did you mean one of the "
                    f"existing vectors: {', '.join(current_vector_config.keys())}? "
                    f"If you want to recreate the collection, set `force_recreate` "
                    f"parameter to `True`."
                )
            current_vector_config = current_vector_config.get(
                vector_name
            )  # type: ignore[assignment]
        elif isinstance(current_vector_config, dict) and vector_name is None:
            raise QdrantException(
                f"Existing Qdrant collection {collection_name} uses named vectors. "
                f"If you want to reuse it, please set `vector_name` to any of the "
                f"existing named vectors: "
                f"{', '.join(current_vector_config.keys())}. "  # noqa
                f"If you want to recreate the collection, set `force_recreate` "
                f"parameter to `True`."
            )
        elif (
            not isinstance(current_vector_config, dict) and vector_name is not None
        ):
            raise QdrantException(
                f"Existing Qdrant collection {collection_name} doesn't use named "
                f"vectors. If you want to reuse it, please set `vector_name` to "
                f"`None`. If you want to recreate the collection, set "
                f"`force_recreate` parameter to `True`."
            )

        # Check if the vector configuration has the same dimensionality.
        if current_vector_config.size != vector_size:  # type: ignore[union-attr]
            raise QdrantException(
                f"Existing Qdrant collection is configured for vectors with "
                f"{current_vector_config.size} "  # type: ignore[union-attr]
                f"dimensions. Selected embeddings are {vector_size}-dimensional. "
                f"If you want to recreate the collection, set `force_recreate` "
                f"parameter to `True`."
            )

        current_distance_func = (
            current_vector_config.distance.name.upper()  # type: ignore[union-attr]
        )
        if current_distance_func != distance_func:
            raise QdrantException(
                f"Existing Qdrant collection is configured for "
                f"{current_vector_config.distance} "  # type: ignore[union-attr]
                f"similarity. Please set `distance_func` parameter to "
                f"`{distance_func}` if you want to reuse it. If you want to "
                f"recreate the collection, set `force_recreate` parameter to "
                f"`True`."
            )
    except (UnexpectedResponse, RpcError, ValueError):
        # Reached when the collection could not be fetched or when
        # `force_recreate` was requested: build the collection from scratch.
        vectors_config = rest.VectorParams(
            size=vector_size,
            distance=rest.Distance[distance_func],
            on_disk=on_disk,
        )

        # If vector name was provided, we're going to use the named vectors feature
        # with just a single vector.
        if vector_name is not None:
            vectors_config = {  # type: ignore[assignment]
                vector_name: vectors_config,
            }

        client.recreate_collection(
            collection_name=collection_name,
            vectors_config=vectors_config,
            shard_number=shard_number,
            replication_factor=replication_factor,
            write_consistency_factor=write_consistency_factor,
            on_disk_payload=on_disk_payload,
            hnsw_config=hnsw_config,
            optimizers_config=optimizers_config,
            wal_config=wal_config,
            quantization_config=quantization_config,
            init_from=init_from,
            timeout=timeout,  # type: ignore[arg-type]
        )
    qdrant = cls(
        client=client,
        collection_name=collection_name,
        embeddings=embedding,
        content_payload_key=content_payload_key,
        metadata_payload_key=metadata_payload_key,
        distance_strategy=distance_func,
        vector_name=vector_name,
    )
    return qdrant
|
||||||
|
|
||||||
def _select_relevance_score_fn(self) -> Callable[[float], float]:
|
def _select_relevance_score_fn(self) -> Callable[[float], float]:
|
||||||
"""
|
"""
|
||||||
The 'correct' relevance function
|
The 'correct' relevance function
|
||||||
@ -1648,6 +1815,33 @@ class Qdrant(VectorStore):
|
|||||||
|
|
||||||
return embeddings
|
return embeddings
|
||||||
|
|
||||||
|
async def _aembed_texts(self, texts: Iterable[str]) -> List[List[float]]:
    """Embed texts, awaiting the async embeddings API when it is available.

    Used to provide backward compatibility with `embedding_function` argument.

    Args:
        texts: Iterable of texts to embed.

    Returns:
        One embedding (list of floats) per input text.

    Raises:
        ValueError: If neither `embeddings` nor `embedding_function` is set.
    """
    if self.embeddings is not None:
        embeddings = await self.embeddings.aembed_documents(list(texts))
        # Array-like results (anything exposing `tolist`) become plain lists.
        if hasattr(embeddings, "tolist"):
            embeddings = embeddings.tolist()
    elif self._embeddings_function is not None:
        # The legacy callable embeds one text at a time and is called
        # synchronously (it is not awaited).
        embeddings = []
        for text in texts:
            embedding = self._embeddings_function(text)
            # Inspect the individual vector, not the accumulating list: a
            # plain list never has `tolist`, so checking `embeddings` would
            # silently skip the conversion.
            if hasattr(embedding, "tolist"):
                embedding = embedding.tolist()
            embeddings.append(embedding)
    else:
        raise ValueError("Neither of embeddings or embedding_function is set")

    return embeddings
|
||||||
|
|
||||||
def _generate_rest_batches(
|
def _generate_rest_batches(
|
||||||
self,
|
self,
|
||||||
texts: Iterable[str],
|
texts: Iterable[str],
|
||||||
@ -1689,3 +1883,45 @@ class Qdrant(VectorStore):
|
|||||||
]
|
]
|
||||||
|
|
||||||
yield batch_ids, points
|
yield batch_ids, points
|
||||||
|
|
||||||
|
async def _agenerate_rest_batches(
    self,
    texts: Iterable[str],
    metadatas: Optional[List[dict]] = None,
    ids: Optional[Sequence[str]] = None,
    batch_size: int = 64,
) -> AsyncGenerator[Tuple[List[str], List[rest.PointStruct]], None]:
    """Asynchronously yield batches of point ids and REST point structures.

    Texts are consumed `batch_size` at a time, embedded through
    `self._aembed_texts`, and paired with their ids and payloads.

    Args:
        texts: Iterable of texts to embed and store. May be a one-shot
            iterator; it is traversed only once.
        metadatas: Optional metadata dicts, consumed in step with `texts`.
        ids: Optional point ids, consumed in step with `texts`. When omitted
            (or empty), a random hex UUID is generated for every text.
        batch_size: Maximum number of texts per yielded batch.

    Yields:
        Tuples of `(batch_ids, points)` for each batch.
    """
    from qdrant_client.http import models as rest

    texts_iterator = iter(texts)
    metadatas_iterator = iter(metadatas or [])
    # Ids are generated per batch when the caller gave none. Pre-generating
    # them by iterating `texts` a second time would exhaust a one-shot
    # iterable before the batching loop could read it.
    ids_iterator = iter(ids) if ids else None

    while batch_texts := list(islice(texts_iterator, batch_size)):
        # Take the corresponding metadata and id for each text in a batch
        batch_metadatas = list(islice(metadatas_iterator, batch_size)) or None
        if ids_iterator is None:
            batch_ids = [uuid.uuid4().hex for _ in batch_texts]
        else:
            batch_ids = list(islice(ids_iterator, batch_size))

        # Generate the embeddings for all the texts in a batch
        batch_embeddings = await self._aembed_texts(batch_texts)

        points = [
            rest.PointStruct(
                id=point_id,
                # Named-vector collections expect a {name: vector} mapping.
                vector=vector
                if self.vector_name is None
                else {self.vector_name: vector},
                payload=payload,
            )
            for point_id, vector, payload in zip(
                batch_ids,
                batch_embeddings,
                self._build_payloads(
                    batch_texts,
                    batch_metadatas,
                    self.content_payload_key,
                    self.metadata_payload_key,
                ),
            )
        ]

        yield batch_ids, points
|
||||||
|
@ -15,6 +15,9 @@ class FakeEmbeddings(Embeddings):
|
|||||||
Embeddings encode each text as its index."""
|
Embeddings encode each text as its index."""
|
||||||
return [[float(1.0)] * 9 + [float(i)] for i in range(len(texts))]
|
return [[float(1.0)] * 9 + [float(i)] for i in range(len(texts))]
|
||||||
|
|
||||||
|
async def aembed_documents(self, texts: List[str]) -> List[List[float]]:
    """Async counterpart of `embed_documents`; delegates to the sync method."""
    document_vectors = self.embed_documents(texts)
    return document_vectors
|
||||||
|
|
||||||
def embed_query(self, text: str) -> List[float]:
|
def embed_query(self, text: str) -> List[float]:
|
||||||
"""Return constant query embeddings.
|
"""Return constant query embeddings.
|
||||||
Embeddings are identical to embed_documents(texts)[0].
|
Embeddings are identical to embed_documents(texts)[0].
|
||||||
@ -22,6 +25,9 @@ class FakeEmbeddings(Embeddings):
|
|||||||
as it was passed to embed_documents."""
|
as it was passed to embed_documents."""
|
||||||
return [float(1.0)] * 9 + [float(0.0)]
|
return [float(1.0)] * 9 + [float(0.0)]
|
||||||
|
|
||||||
|
async def aembed_query(self, text: str) -> List[float]:
    """Async counterpart of `embed_query`; delegates to the sync method."""
    query_vector = self.embed_query(text)
    return query_vector
|
||||||
|
|
||||||
|
|
||||||
class ConsistentFakeEmbeddings(FakeEmbeddings):
|
class ConsistentFakeEmbeddings(FakeEmbeddings):
|
||||||
"""Fake embeddings which remember all the texts seen so far to return consistent
|
"""Fake embeddings which remember all the texts seen so far to return consistent
|
||||||
|
Loading…
Reference in New Issue
Block a user