From b53548dcda8fe1dd820f7db31db6b1f3bff6c360 Mon Sep 17 00:00:00 2001
From: Wang Guan
Date: Tue, 14 May 2024 00:18:04 +0900
Subject: [PATCH] langchain[minor]: allow CacheBackedEmbeddings to cache queries (#20073)

Add optional caching of queries to cache-backed embeddings
---
 docs/docs/how_to/caching_embeddings.ipynb    |  3 +-
 libs/langchain/langchain/embeddings/cache.py | 75 ++++++++++++++-----
 .../unit_tests/embeddings/test_caching.py    | 35 +++++++++
 3 files changed, 92 insertions(+), 21 deletions(-)

diff --git a/docs/docs/how_to/caching_embeddings.ipynb b/docs/docs/how_to/caching_embeddings.ipynb
index f9d6ad59dc9..232ba7bac0e 100644
--- a/docs/docs/how_to/caching_embeddings.ipynb
+++ b/docs/docs/how_to/caching_embeddings.ipynb
@@ -18,11 +18,12 @@
     "- document_embedding_cache: Any [`ByteStore`](/docs/integrations/stores/) for caching document embeddings.\n",
     "- batch_size: (optional, defaults to `None`) The number of documents to embed between store updates.\n",
     "- namespace: (optional, defaults to `\"\"`) The namespace to use for document cache. This namespace is used to avoid collisions with other caches. For example, set it to the name of the embedding model used.\n",
+    "- query_embedding_cache: (optional, defaults to `False`, which disables query caching) A [`ByteStore`](/docs/integrations/stores/) for caching query embeddings, or `True` to use the same store as `document_embedding_cache`.\n",
     "\n",
     "**Attention**:\n",
     "\n",
     "- Be sure to set the `namespace` parameter to avoid collisions of the same text embedded using different embeddings models.\n",
-    "- Currently `CacheBackedEmbeddings` does not cache embedding created with `embed_query()` `aembed_query()` methods."
+    "- `CacheBackedEmbeddings` does not cache query embeddings by default. To enable query caching, one needs to specify a `query_embedding_cache`."
   ]
  },
 {

diff --git a/libs/langchain/langchain/embeddings/cache.py b/libs/langchain/langchain/embeddings/cache.py
index 05b7e6d923e..9cbf71a5aab 100644
--- a/libs/langchain/langchain/embeddings/cache.py
+++ b/libs/langchain/langchain/embeddings/cache.py
@@ -6,6 +6,7 @@ embeddings for the same text.

 The text is hashed and the hash is used as the key in the cache.
 """
+
 from __future__ import annotations

 import hashlib
@@ -59,6 +60,9 @@ class CacheBackedEmbeddings(Embeddings):
     If need be, the interface can be extended to accept other implementations
     of the value serializer and deserializer, as well as the key encoder.

+    Note that by default only document embeddings are cached. To cache query
+    embeddings too, pass a query_embedding_store to the constructor.
+
     Examples:

         .. code-block: python
@@ -87,6 +91,7 @@ class CacheBackedEmbeddings(Embeddings):
         document_embedding_store: BaseStore[str, List[float]],
         *,
         batch_size: Optional[int] = None,
+        query_embedding_store: Optional[BaseStore[str, List[float]]] = None,
     ) -> None:
         """Initialize the embedder.

         Args:
             underlying_embeddings: the embedder to use for computing embeddings.
             document_embedding_store: The store to use for caching document embeddings.
             batch_size: The number of documents to embed between store updates.
+            query_embedding_store: The store to use for caching query embeddings.
+                If None, query embeddings are not cached.
""" super().__init__() self.document_embedding_store = document_embedding_store + self.query_embedding_store = query_embedding_store self.underlying_embeddings = underlying_embeddings self.batch_size = batch_size @@ -173,14 +181,8 @@ class CacheBackedEmbeddings(Embeddings): def embed_query(self, text: str) -> List[float]: """Embed query text. - This method does not support caching at the moment. - - Support for caching queries is easily to implement, but might make - sense to hold off to see the most common patterns. - - If the cache has an eviction policy, we may need to be a bit more careful - about sharing the cache between documents and queries. Generally, - one is OK evicting query caches, but document caches should be kept. + By default, this method does not cache queries. To enable caching, set the + `cache_query` parameter to `True` when initializing the embedder. Args: text: The text to embed. @@ -188,19 +190,22 @@ class CacheBackedEmbeddings(Embeddings): Returns: The embedding for the given text. """ - return self.underlying_embeddings.embed_query(text) + if not self.query_embedding_store: + return self.underlying_embeddings.embed_query(text) + + (cached,) = self.query_embedding_store.mget([text]) + if cached is not None: + return cached + + vector = self.underlying_embeddings.embed_query(text) + self.query_embedding_store.mset([(text, vector)]) + return vector async def aembed_query(self, text: str) -> List[float]: """Embed query text. - This method does not support caching at the moment. - - Support for caching queries is easily to implement, but might make - sense to hold off to see the most common patterns. - - If the cache has an eviction policy, we may need to be a bit more careful - about sharing the cache between documents and queries. Generally, - one is OK evicting query caches, but document caches should be kept. + By default, this method does not cache queries. To enable caching, set the + `cache_query` parameter to `True` when initializing the embedder. Args: text: The text to embed. @@ -208,7 +213,16 @@ class CacheBackedEmbeddings(Embeddings): Returns: The embedding for the given text. """ - return await self.underlying_embeddings.aembed_query(text) + if not self.query_embedding_store: + return await self.underlying_embeddings.aembed_query(text) + + (cached,) = await self.query_embedding_store.amget([text]) + if cached is not None: + return cached + + vector = await self.underlying_embeddings.aembed_query(text) + await self.query_embedding_store.amset([(text, vector)]) + return vector @classmethod def from_bytes_store( @@ -218,6 +232,7 @@ class CacheBackedEmbeddings(Embeddings): *, namespace: str = "", batch_size: Optional[int] = None, + query_embedding_cache: Union[bool, ByteStore] = False, ) -> CacheBackedEmbeddings: """On-ramp that adds the necessary serialization and encoding to the store. @@ -229,13 +244,33 @@ class CacheBackedEmbeddings(Embeddings): This namespace is used to avoid collisions with other caches. For example, set it to the name of the embedding model used. batch_size: The number of documents to embed between store updates. + query_embedding_cache: The cache to use for storing query embeddings. + True to use the same cache as document embeddings. + False to not cache query embeddings. 
""" namespace = namespace key_encoder = _create_key_encoder(namespace) - encoder_backed_store = EncoderBackedStore[str, List[float]]( + document_embedding_store = EncoderBackedStore[str, List[float]]( document_embedding_cache, key_encoder, _value_serializer, _value_deserializer, ) - return cls(underlying_embeddings, encoder_backed_store, batch_size=batch_size) + if query_embedding_cache is True: + query_embedding_store = document_embedding_store + elif query_embedding_cache is False: + query_embedding_store = None + else: + query_embedding_store = EncoderBackedStore[str, List[float]]( + query_embedding_cache, + key_encoder, + _value_serializer, + _value_deserializer, + ) + + return cls( + underlying_embeddings, + document_embedding_store, + batch_size=batch_size, + query_embedding_store=query_embedding_store, + ) diff --git a/libs/langchain/tests/unit_tests/embeddings/test_caching.py b/libs/langchain/tests/unit_tests/embeddings/test_caching.py index 154f248d649..e24bbb1a22c 100644 --- a/libs/langchain/tests/unit_tests/embeddings/test_caching.py +++ b/libs/langchain/tests/unit_tests/embeddings/test_caching.py @@ -43,6 +43,20 @@ def cache_embeddings_batch() -> CacheBackedEmbeddings: ) +@pytest.fixture +def cache_embeddings_with_query() -> CacheBackedEmbeddings: + """Create a cache backed embeddings with query caching.""" + doc_store = InMemoryStore() + query_store = InMemoryStore() + embeddings = MockEmbeddings() + return CacheBackedEmbeddings.from_bytes_store( + embeddings, + document_embedding_cache=doc_store, + namespace="test_namespace", + query_embedding_cache=query_store, + ) + + def test_embed_documents(cache_embeddings: CacheBackedEmbeddings) -> None: texts = ["1", "22", "a", "333"] vectors = cache_embeddings.embed_documents(texts) @@ -73,6 +87,17 @@ def test_embed_query(cache_embeddings: CacheBackedEmbeddings) -> None: vector = cache_embeddings.embed_query(text) expected_vector = [5.0, 6.0] assert vector == expected_vector + assert cache_embeddings.query_embedding_store is None + + +def test_embed_cached_query(cache_embeddings_with_query: CacheBackedEmbeddings) -> None: + text = "query_text" + vector = cache_embeddings_with_query.embed_query(text) + expected_vector = [5.0, 6.0] + assert vector == expected_vector + keys = list(cache_embeddings_with_query.query_embedding_store.yield_keys()) # type: ignore[union-attr] + assert len(keys) == 1 + assert keys[0] == "test_namespace89ec3dae-a4d9-5636-a62e-ff3b56cdfa15" async def test_aembed_documents(cache_embeddings: CacheBackedEmbeddings) -> None: @@ -112,3 +137,13 @@ async def test_aembed_query(cache_embeddings: CacheBackedEmbeddings) -> None: vector = await cache_embeddings.aembed_query(text) expected_vector = [5.0, 6.0] assert vector == expected_vector + + +async def test_aembed_query_cached( + cache_embeddings_with_query: CacheBackedEmbeddings, +) -> None: + text = "query_text" + await cache_embeddings_with_query.aembed_query(text) + keys = list(cache_embeddings_with_query.query_embedding_store.yield_keys()) # type: ignore[union-attr] + assert len(keys) == 1 + assert keys[0] == "test_namespace89ec3dae-a4d9-5636-a62e-ff3b56cdfa15"