mirror of
https://github.com/hwchase17/langchain.git
synced 2025-06-05 22:53:30 +00:00
Add async methods to CacheBackedEmbeddings (#16873)
Adds async methods to CacheBackedEmbeddings
This commit is contained in:
parent
dd68a8716e
commit
a8f530bc4d
@ -128,6 +128,41 @@ class CacheBackedEmbeddings(Embeddings):
|
|||||||
List[List[float]], vectors
|
List[List[float]], vectors
|
||||||
) # Nones should have been resolved by now
|
) # Nones should have been resolved by now
|
||||||
|
|
||||||
|
async def aembed_documents(self, texts: List[str]) -> List[List[float]]:
|
||||||
|
"""Embed a list of texts.
|
||||||
|
|
||||||
|
The method first checks the cache for the embeddings.
|
||||||
|
If the embeddings are not found, the method uses the underlying embedder
|
||||||
|
to embed the documents and stores the results in the cache.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
texts: A list of texts to embed.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
A list of embeddings for the given texts.
|
||||||
|
"""
|
||||||
|
vectors: List[
|
||||||
|
Union[List[float], None]
|
||||||
|
] = await self.document_embedding_store.amget(texts)
|
||||||
|
missing_indices: List[int] = [
|
||||||
|
i for i, vector in enumerate(vectors) if vector is None
|
||||||
|
]
|
||||||
|
missing_texts = [texts[i] for i in missing_indices]
|
||||||
|
|
||||||
|
if missing_texts:
|
||||||
|
missing_vectors = await self.underlying_embeddings.aembed_documents(
|
||||||
|
missing_texts
|
||||||
|
)
|
||||||
|
await self.document_embedding_store.amset(
|
||||||
|
list(zip(missing_texts, missing_vectors))
|
||||||
|
)
|
||||||
|
for index, updated_vector in zip(missing_indices, missing_vectors):
|
||||||
|
vectors[index] = updated_vector
|
||||||
|
|
||||||
|
return cast(
|
||||||
|
List[List[float]], vectors
|
||||||
|
) # Nones should have been resolved by now
|
||||||
|
|
||||||
def embed_query(self, text: str) -> List[float]:
|
def embed_query(self, text: str) -> List[float]:
|
||||||
"""Embed query text.
|
"""Embed query text.
|
||||||
|
|
||||||
@ -148,6 +183,26 @@ class CacheBackedEmbeddings(Embeddings):
|
|||||||
"""
|
"""
|
||||||
return self.underlying_embeddings.embed_query(text)
|
return self.underlying_embeddings.embed_query(text)
|
||||||
|
|
||||||
|
async def aembed_query(self, text: str) -> List[float]:
|
||||||
|
"""Embed query text.
|
||||||
|
|
||||||
|
This method does not support caching at the moment.
|
||||||
|
|
||||||
|
Support for caching queries is easy to implement, but it might make
|
||||||
|
sense to hold off to see the most common patterns.
|
||||||
|
|
||||||
|
If the cache has an eviction policy, we may need to be a bit more careful
|
||||||
|
about sharing the cache between documents and queries. Generally,
|
||||||
|
one is OK evicting query caches, but document caches should be kept.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
text: The text to embed.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The embedding for the given text.
|
||||||
|
"""
|
||||||
|
return await self.underlying_embeddings.aembed_query(text)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def from_bytes_store(
|
def from_bytes_store(
|
||||||
cls,
|
cls,
|
||||||
|
@ -47,3 +47,23 @@ def test_embed_query(cache_embeddings: CacheBackedEmbeddings) -> None:
|
|||||||
vector = cache_embeddings.embed_query(text)
|
vector = cache_embeddings.embed_query(text)
|
||||||
expected_vector = [5.0, 6.0]
|
expected_vector = [5.0, 6.0]
|
||||||
assert vector == expected_vector
|
assert vector == expected_vector
|
||||||
|
|
||||||
|
|
||||||
|
async def test_aembed_documents(cache_embeddings: CacheBackedEmbeddings) -> None:
|
||||||
|
texts = ["1", "22", "a", "333"]
|
||||||
|
vectors = await cache_embeddings.aembed_documents(texts)
|
||||||
|
expected_vectors: List[List[float]] = [[1, 2.0], [2.0, 3.0], [1.0, 2.0], [3.0, 4.0]]
|
||||||
|
assert vectors == expected_vectors
|
||||||
|
keys = [
|
||||||
|
key async for key in cache_embeddings.document_embedding_store.ayield_keys()
|
||||||
|
]
|
||||||
|
assert len(keys) == 4
|
||||||
|
# UUID is expected to be the same for the same text
|
||||||
|
assert keys[0] == "test_namespace812b86c1-8ebf-5483-95c6-c95cf2b52d12"
|
||||||
|
|
||||||
|
|
||||||
|
async def test_aembed_query(cache_embeddings: CacheBackedEmbeddings) -> None:
|
||||||
|
text = "query_text"
|
||||||
|
vector = await cache_embeddings.aembed_query(text)
|
||||||
|
expected_vector = [5.0, 6.0]
|
||||||
|
assert vector == expected_vector
|
||||||
|
Loading…
Reference in New Issue
Block a user