Add async methods to CacheBackedEmbeddings (#16873)
Adds async counterparts aembed_documents and aembed_query to CacheBackedEmbeddings, along with unit tests.
This commit is contained in:
parent dd68a8716e
commit a8f530bc4d
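
For orientation, here is a minimal usage sketch of the new async API. It is illustrative only: the store and embedding model are arbitrary choices, and the import paths are assumptions that may differ between langchain versions.

import asyncio

from langchain.embeddings import CacheBackedEmbeddings
from langchain.storage import InMemoryStore
from langchain_openai import OpenAIEmbeddings  # any Embeddings implementation works

underlying = OpenAIEmbeddings()  # illustrative choice; requires an OpenAI API key
store = InMemoryStore()
cached = CacheBackedEmbeddings.from_bytes_store(
    underlying, store, namespace=underlying.model
)


async def main() -> None:
    # First call embeds via the underlying model and writes the vectors to the store.
    vectors = await cached.aembed_documents(["hello", "world"])
    # Second call finds both texts in the cache, so no model round trip is made.
    cached_vectors = await cached.aembed_documents(["hello", "world"])
    assert vectors == cached_vectors
    # Queries are not cached; this simply awaits the underlying embedder.
    await cached.aembed_query("hello")


asyncio.run(main())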
@@ -128,6 +128,41 @@ class CacheBackedEmbeddings(Embeddings):
             List[List[float]], vectors
         )  # Nones should have been resolved by now
 
+    async def aembed_documents(self, texts: List[str]) -> List[List[float]]:
+        """Embed a list of texts.
+
+        The method first checks the cache for the embeddings.
+        If the embeddings are not found, the method uses the underlying embedder
+        to embed the documents and stores the results in the cache.
+
+        Args:
+            texts: A list of texts to embed.
+
+        Returns:
+            A list of embeddings for the given texts.
+        """
+        vectors: List[
+            Union[List[float], None]
+        ] = await self.document_embedding_store.amget(texts)
+        missing_indices: List[int] = [
+            i for i, vector in enumerate(vectors) if vector is None
+        ]
+        missing_texts = [texts[i] for i in missing_indices]
+
+        if missing_texts:
+            missing_vectors = await self.underlying_embeddings.aembed_documents(
+                missing_texts
+            )
+            await self.document_embedding_store.amset(
+                list(zip(missing_texts, missing_vectors))
+            )
+            for index, updated_vector in zip(missing_indices, missing_vectors):
+                vectors[index] = updated_vector
+
+        return cast(
+            List[List[float]], vectors
+        )  # Nones should have been resolved by now
+
     def embed_query(self, text: str) -> List[float]:
         """Embed query text.
 
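
The async path above relies on the document_embedding_store exposing async counterparts of mget/mset. As a rough mental model only, a toy in-memory store satisfying the methods this change uses (amget, amset, and ayield_keys in the tests) might look like the sketch below; it is not langchain's actual BaseStore.

from typing import AsyncIterator, Dict, List, Optional, Sequence, Tuple


class ToyAsyncVectorStore:
    """Illustrative in-memory stand-in for the async store methods used above."""

    def __init__(self) -> None:
        self._data: Dict[str, List[float]] = {}

    async def amget(self, keys: Sequence[str]) -> List[Optional[List[float]]]:
        # One slot per key; None marks a cache miss for aembed_documents to fill.
        return [self._data.get(key) for key in keys]

    async def amset(self, key_value_pairs: Sequence[Tuple[str, List[float]]]) -> None:
        # Persist the freshly computed embeddings.
        for key, value in key_value_pairs:
            self._data[key] = value

    async def ayield_keys(self, *, prefix: Optional[str] = None) -> AsyncIterator[str]:
        # Iterate stored keys, optionally filtered by prefix.
        for key in self._data:
            if prefix is None or key.startswith(prefix):
                yield key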
@@ -148,6 +183,26 @@ class CacheBackedEmbeddings(Embeddings):
         """
         return self.underlying_embeddings.embed_query(text)
 
+    async def aembed_query(self, text: str) -> List[float]:
+        """Embed query text.
+
+        This method does not support caching at the moment.
+
+        Support for caching queries is easy to implement, but it might make
+        sense to hold off to see the most common patterns.
+
+        If the cache has an eviction policy, we may need to be a bit more careful
+        about sharing the cache between documents and queries. Generally,
+        one is OK evicting query caches, but document caches should be kept.
+
+        Args:
+            text: The text to embed.
+
+        Returns:
+            The embedding for the given text.
+        """
+        return await self.underlying_embeddings.aembed_query(text)
+
     @classmethod
     def from_bytes_store(
         cls,
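
As the docstring notes, query caching is deliberately left out for now. A caller who wants it can layer a small, evictable cache on top of aembed_query; the wrapper below is a hypothetical sketch (QueryCachingEmbeddings is not part of this change).

from typing import Dict, List

from langchain_core.embeddings import Embeddings


class QueryCachingEmbeddings:
    """Hypothetical wrapper that memoizes query embeddings in memory."""

    def __init__(self, inner: Embeddings) -> None:
        self._inner = inner
        self._query_cache: Dict[str, List[float]] = {}

    async def aembed_query(self, text: str) -> List[float]:
        # Evicting this cache is harmless; document caches are the ones worth keeping.
        if text not in self._query_cache:
            self._query_cache[text] = await self._inner.aembed_query(text)
        return self._query_cache[text]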
@@ -47,3 +47,23 @@ def test_embed_query(cache_embeddings: CacheBackedEmbeddings) -> None:
     vector = cache_embeddings.embed_query(text)
     expected_vector = [5.0, 6.0]
     assert vector == expected_vector
+
+
+async def test_aembed_documents(cache_embeddings: CacheBackedEmbeddings) -> None:
+    texts = ["1", "22", "a", "333"]
+    vectors = await cache_embeddings.aembed_documents(texts)
+    expected_vectors: List[List[float]] = [[1, 2.0], [2.0, 3.0], [1.0, 2.0], [3.0, 4.0]]
+    assert vectors == expected_vectors
+    keys = [
+        key async for key in cache_embeddings.document_embedding_store.ayield_keys()
+    ]
+    assert len(keys) == 4
+    # UUID is expected to be the same for the same text
+    assert keys[0] == "test_namespace812b86c1-8ebf-5483-95c6-c95cf2b52d12"
+
+
+async def test_aembed_query(cache_embeddings: CacheBackedEmbeddings) -> None:
+    text = "query_text"
+    vector = await cache_embeddings.aembed_query(text)
+    expected_vector = [5.0, 6.0]
+    assert vector == expected_vector
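
The cache_embeddings fixture is not part of this hunk. Judging from the expected values above (documents embed to [len(text), len(text) + 1.0], queries to [5.0, 6.0], namespace "test_namespace"), a plausible reconstruction is sketched below; the real fixture lives elsewhere in the test module and may differ.

from typing import List

import pytest

from langchain.embeddings import CacheBackedEmbeddings
from langchain.storage import InMemoryStore
from langchain_core.embeddings import Embeddings


class MockEmbeddings(Embeddings):
    def embed_documents(self, texts: List[str]) -> List[List[float]]:
        # Deterministic fake: the vector depends only on the text's length.
        return [[float(len(text)), float(len(text)) + 1.0] for text in texts]

    def embed_query(self, text: str) -> List[float]:
        return [5.0, 6.0]


@pytest.fixture
def cache_embeddings() -> CacheBackedEmbeddings:
    store = InMemoryStore()
    return CacheBackedEmbeddings.from_bytes_store(
        MockEmbeddings(), store, namespace="test_namespace"
    )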