diff --git a/libs/langchain/langchain/embeddings/cache.py b/libs/langchain/langchain/embeddings/cache.py
index 382827991ff..51d3969e539 100644
--- a/libs/langchain/langchain/embeddings/cache.py
+++ b/libs/langchain/langchain/embeddings/cache.py
@@ -128,6 +128,41 @@ class CacheBackedEmbeddings(Embeddings):
             List[List[float]], vectors
         )  # Nones should have been resolved by now
 
+    async def aembed_documents(self, texts: List[str]) -> List[List[float]]:
+        """Embed a list of texts.
+
+        The method first checks the cache for the embeddings.
+        If the embeddings are not found, the method uses the underlying embedder
+        to embed the documents and stores the results in the cache.
+
+        Args:
+            texts: A list of texts to embed.
+
+        Returns:
+            A list of embeddings for the given texts.
+        """
+        vectors: List[
+            Union[List[float], None]
+        ] = await self.document_embedding_store.amget(texts)
+        missing_indices: List[int] = [
+            i for i, vector in enumerate(vectors) if vector is None
+        ]
+        missing_texts = [texts[i] for i in missing_indices]
+
+        if missing_texts:
+            missing_vectors = await self.underlying_embeddings.aembed_documents(
+                missing_texts
+            )
+            await self.document_embedding_store.amset(
+                list(zip(missing_texts, missing_vectors))
+            )
+            for index, updated_vector in zip(missing_indices, missing_vectors):
+                vectors[index] = updated_vector
+
+        return cast(
+            List[List[float]], vectors
+        )  # Nones should have been resolved by now
+
     def embed_query(self, text: str) -> List[float]:
         """Embed query text.
 
@@ -148,6 +183,26 @@ class CacheBackedEmbeddings(Embeddings):
         """
         return self.underlying_embeddings.embed_query(text)
 
+    async def aembed_query(self, text: str) -> List[float]:
+        """Embed query text.
+
+        This method does not support caching at the moment.
+
+        Support for caching queries is easy to implement, but might make
+        sense to hold off to see the most common patterns.
+
+        If the cache has an eviction policy, we may need to be a bit more careful
+        about sharing the cache between documents and queries. Generally,
+        one is OK evicting query caches, but document caches should be kept.
+
+        Args:
+            text: The text to embed.
+
+        Returns:
+            The embedding for the given text.
+        """
+        return await self.underlying_embeddings.aembed_query(text)
+
     @classmethod
     def from_bytes_store(
         cls,
diff --git a/libs/langchain/tests/unit_tests/embeddings/test_caching.py b/libs/langchain/tests/unit_tests/embeddings/test_caching.py
index 48c6adbab66..8c24e73b95b 100644
--- a/libs/langchain/tests/unit_tests/embeddings/test_caching.py
+++ b/libs/langchain/tests/unit_tests/embeddings/test_caching.py
@@ -47,3 +47,23 @@ def test_embed_query(cache_embeddings: CacheBackedEmbeddings) -> None:
     vector = cache_embeddings.embed_query(text)
     expected_vector = [5.0, 6.0]
     assert vector == expected_vector
+
+
+async def test_aembed_documents(cache_embeddings: CacheBackedEmbeddings) -> None:
+    texts = ["1", "22", "a", "333"]
+    vectors = await cache_embeddings.aembed_documents(texts)
+    expected_vectors: List[List[float]] = [[1, 2.0], [2.0, 3.0], [1.0, 2.0], [3.0, 4.0]]
+    assert vectors == expected_vectors
+    keys = [
+        key async for key in cache_embeddings.document_embedding_store.ayield_keys()
+    ]
+    assert len(keys) == 4
+    # UUID is expected to be the same for the same text
+    assert keys[0] == "test_namespace812b86c1-8ebf-5483-95c6-c95cf2b52d12"
+
+
+async def test_aembed_query(cache_embeddings: CacheBackedEmbeddings) -> None:
+    text = "query_text"
+    vector = await cache_embeddings.aembed_query(text)
+    expected_vector = [5.0, 6.0]
+    assert vector == expected_vector