langchain[minor]: allow CacheBackedEmbeddings to cache queries (#20073)

Add optional caching of queries to cache-backed embeddings
Wang Guan 2024-05-14 00:18:04 +09:00 committed by GitHub
parent a156aace2b
commit b53548dcda
3 changed files with 92 additions and 21 deletions

View File

@@ -18,11 +18,12 @@
     "- document_embedding_cache: Any [`ByteStore`](/docs/integrations/stores/) for caching document embeddings.\n",
     "- batch_size: (optional, defaults to `None`) The number of documents to embed between store updates.\n",
     "- namespace: (optional, defaults to `\"\"`) The namespace to use for document cache. This namespace is used to avoid collisions with other caches. For example, set it to the name of the embedding model used.\n",
+    "- query_embedding_cache: (optional, defaults to `None`, meaning queries are not cached) A [`ByteStore`](/docs/integrations/stores/) for caching query embeddings, or `True` to use the same store as `document_embedding_cache`.\n",
     "\n",
     "**Attention**:\n",
     "\n",
     "- Be sure to set the `namespace` parameter to avoid collisions of the same text embedded using different embeddings models.\n",
-    "- Currently `CacheBackedEmbeddings` does not cache embedding created with `embed_query()` `aembed_query()` methods."
+    "- `CacheBackedEmbeddings` does not cache query embeddings by default. To enable query caching, one needs to specify a `query_embedding_cache`."
   ]
  },
  {
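
For illustration, here is how the documented parameters fit together once this change lands. A minimal sketch, with a toy FakeEmbeddings class standing in for a real embedding model (the stub is not part of this commit):

from typing import List

from langchain.embeddings import CacheBackedEmbeddings
from langchain.storage import InMemoryByteStore
from langchain_core.embeddings import Embeddings


class FakeEmbeddings(Embeddings):
    """Deterministic stand-in for a real embedding model."""

    def embed_documents(self, texts: List[str]) -> List[List[float]]:
        return [[float(len(text)), 0.0] for text in texts]

    def embed_query(self, text: str) -> List[float]:
        return [float(len(text)), 1.0]


embedder = CacheBackedEmbeddings.from_bytes_store(
    FakeEmbeddings(),
    document_embedding_cache=InMemoryByteStore(),
    namespace="fake-model-v1",   # set per model to avoid collisions
    query_embedding_cache=True,  # share the document store for queries
)
embedder.embed_query("hello")  # computed by the model, then written to the store
embedder.embed_query("hello")  # served from the cache
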

View File

@@ -6,6 +6,7 @@ embeddings for the same text.
 The text is hashed and the hash is used as the key in the cache.
 """
 from __future__ import annotations

 import hashlib
@@ -59,6 +60,9 @@ class CacheBackedEmbeddings(Embeddings):
     If need be, the interface can be extended to accept other implementations
     of the value serializer and deserializer, as well as the key encoder.

+    Note that by default only document embeddings are cached. To cache query
+    embeddings too, pass a query_embedding_store to the constructor.
+
     Examples:

         .. code-block:: python
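
Concretely, the constructor path mentioned in the new docstring lines looks like this. A sketch, assuming langchain.storage.InMemoryStore and reusing the FakeEmbeddings stub from the earlier sketch; stores passed to the constructor hold List[float] values directly, while from_bytes_store is the variant that wraps a raw ByteStore for you:

from langchain.embeddings import CacheBackedEmbeddings
from langchain.storage import InMemoryStore

embedder = CacheBackedEmbeddings(
    FakeEmbeddings(),                       # any Embeddings implementation
    InMemoryStore(),                        # document_embedding_store
    query_embedding_store=InMemoryStore(),  # opt in to query caching
)
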
@@ -87,6 +91,7 @@ class CacheBackedEmbeddings(Embeddings):
         document_embedding_store: BaseStore[str, List[float]],
         *,
         batch_size: Optional[int] = None,
+        query_embedding_store: Optional[BaseStore[str, List[float]]] = None,
     ) -> None:
         """Initialize the embedder.
@@ -94,9 +99,12 @@ class CacheBackedEmbeddings(Embeddings):
             underlying_embeddings: the embedder to use for computing embeddings.
             document_embedding_store: The store to use for caching document embeddings.
             batch_size: The number of documents to embed between store updates.
+            query_embedding_store: The store to use for caching query embeddings.
+                If None, query embeddings are not cached.
         """
         super().__init__()
         self.document_embedding_store = document_embedding_store
+        self.query_embedding_store = query_embedding_store
         self.underlying_embeddings = underlying_embeddings
         self.batch_size = batch_size
@@ -173,14 +181,8 @@ class CacheBackedEmbeddings(Embeddings):
     def embed_query(self, text: str) -> List[float]:
         """Embed query text.

-        This method does not support caching at the moment.
-
-        Support for caching queries is easily to implement, but might make
-        sense to hold off to see the most common patterns.
-
-        If the cache has an eviction policy, we may need to be a bit more careful
-        about sharing the cache between documents and queries. Generally,
-        one is OK evicting query caches, but document caches should be kept.
+        By default, this method does not cache queries. To enable caching, set
+        the `query_embedding_cache` parameter when initializing the embedder.

         Args:
             text: The text to embed.
@@ -188,19 +190,22 @@ class CacheBackedEmbeddings(Embeddings):
         Returns:
             The embedding for the given text.
         """
-        return self.underlying_embeddings.embed_query(text)
+        if not self.query_embedding_store:
+            return self.underlying_embeddings.embed_query(text)
+
+        (cached,) = self.query_embedding_store.mget([text])
+        if cached is not None:
+            return cached
+
+        vector = self.underlying_embeddings.embed_query(text)
+        self.query_embedding_store.mset([(text, vector)])
+        return vector

     async def aembed_query(self, text: str) -> List[float]:
         """Embed query text.

-        This method does not support caching at the moment.
-
-        Support for caching queries is easily to implement, but might make
-        sense to hold off to see the most common patterns.
-
-        If the cache has an eviction policy, we may need to be a bit more careful
-        about sharing the cache between documents and queries. Generally,
-        one is OK evicting query caches, but document caches should be kept.
+        By default, this method does not cache queries. To enable caching, set
+        the `query_embedding_cache` parameter when initializing the embedder.

         Args:
             text: The text to embed.
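
The lookup order above (store hit, else underlying model, then write-back) can be sanity-checked with a counting wrapper. A hypothetical check, reusing the FakeEmbeddings stub and imports from the first sketch:

class CountingEmbeddings(FakeEmbeddings):
    """Counts how often the underlying model is actually invoked."""

    calls = 0

    def embed_query(self, text: str) -> List[float]:
        type(self).calls += 1
        return super().embed_query(text)


embedder = CacheBackedEmbeddings.from_bytes_store(
    CountingEmbeddings(),
    document_embedding_cache=InMemoryByteStore(),
    namespace="fake-model-v1",
    query_embedding_cache=InMemoryByteStore(),  # separate, evictable query store
)
embedder.embed_query("hello")
embedder.embed_query("hello")        # mget hit; the model is not called again
assert CountingEmbeddings.calls == 1

Keeping a separate query store also addresses the eviction concern from the old docstring: query entries can be evicted freely without touching the document cache.
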
@@ -208,7 +213,16 @@ class CacheBackedEmbeddings(Embeddings):
         Returns:
             The embedding for the given text.
         """
-        return await self.underlying_embeddings.aembed_query(text)
+        if not self.query_embedding_store:
+            return await self.underlying_embeddings.aembed_query(text)
+
+        (cached,) = await self.query_embedding_store.amget([text])
+        if cached is not None:
+            return cached
+
+        vector = await self.underlying_embeddings.aembed_query(text)
+        await self.query_embedding_store.amset([(text, vector)])
+        return vector

     @classmethod
     def from_bytes_store(
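
The async variant mirrors the sync path through amget/amset. A short sketch, reusing embedder from the previous snippet and assuming the underlying model inherits the default async methods of Embeddings, which delegate to the sync implementations:

import asyncio


async def main() -> None:
    first = await embedder.aembed_query("async hello")   # amget miss; model called
    second = await embedder.aembed_query("async hello")  # amget hit
    assert first == second


asyncio.run(main())
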
@@ -218,6 +232,7 @@ class CacheBackedEmbeddings(Embeddings):
         *,
         namespace: str = "",
         batch_size: Optional[int] = None,
+        query_embedding_cache: Union[bool, ByteStore] = False,
     ) -> CacheBackedEmbeddings:
         """On-ramp that adds the necessary serialization and encoding to the store.
@@ -229,13 +244,33 @@ class CacheBackedEmbeddings(Embeddings):
                 This namespace is used to avoid collisions with other caches.
                 For example, set it to the name of the embedding model used.
             batch_size: The number of documents to embed between store updates.
+            query_embedding_cache: The cache to use for storing query embeddings.
+                True to use the same cache as document embeddings.
+                False to not cache query embeddings.
         """
         namespace = namespace
         key_encoder = _create_key_encoder(namespace)
-        encoder_backed_store = EncoderBackedStore[str, List[float]](
+        document_embedding_store = EncoderBackedStore[str, List[float]](
             document_embedding_cache,
             key_encoder,
             _value_serializer,
             _value_deserializer,
         )
-        return cls(underlying_embeddings, encoder_backed_store, batch_size=batch_size)
+        if query_embedding_cache is True:
+            query_embedding_store = document_embedding_store
+        elif query_embedding_cache is False:
+            query_embedding_store = None
+        else:
+            query_embedding_store = EncoderBackedStore[str, List[float]](
+                query_embedding_cache,
+                key_encoder,
+                _value_serializer,
+                _value_deserializer,
+            )
+        return cls(
+            underlying_embeddings,
+            document_embedding_store,
+            batch_size=batch_size,
+            query_embedding_store=query_embedding_store,
+        )
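
The three accepted values of query_embedding_cache map onto the branch above. A sketch, again using the hypothetical FakeEmbeddings stub:

# Default (False): queries are never cached.
uncached = CacheBackedEmbeddings.from_bytes_store(
    FakeEmbeddings(), InMemoryByteStore(), namespace="m"
)
assert uncached.query_embedding_store is None

# True: queries share the document store (and its eviction policy).
shared = CacheBackedEmbeddings.from_bytes_store(
    FakeEmbeddings(), InMemoryByteStore(), namespace="m",
    query_embedding_cache=True,
)
assert shared.query_embedding_store is shared.document_embedding_store

# Any ByteStore: queries get their own store, wrapped with the same key
# encoder and serializers as the document store.
split = CacheBackedEmbeddings.from_bytes_store(
    FakeEmbeddings(), InMemoryByteStore(), namespace="m",
    query_embedding_cache=InMemoryByteStore(),
)
assert split.query_embedding_store is not split.document_embedding_store
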

View File

@@ -43,6 +43,20 @@ def cache_embeddings_batch() -> CacheBackedEmbeddings:
     )


+@pytest.fixture
+def cache_embeddings_with_query() -> CacheBackedEmbeddings:
+    """Create a cache backed embeddings with query caching."""
+    doc_store = InMemoryStore()
+    query_store = InMemoryStore()
+    embeddings = MockEmbeddings()
+    return CacheBackedEmbeddings.from_bytes_store(
+        embeddings,
+        document_embedding_cache=doc_store,
+        namespace="test_namespace",
+        query_embedding_cache=query_store,
+    )
+
+
 def test_embed_documents(cache_embeddings: CacheBackedEmbeddings) -> None:
     texts = ["1", "22", "a", "333"]
     vectors = cache_embeddings.embed_documents(texts)
@@ -73,6 +87,17 @@ def test_embed_query(cache_embeddings: CacheBackedEmbeddings) -> None:
     vector = cache_embeddings.embed_query(text)
     expected_vector = [5.0, 6.0]
     assert vector == expected_vector
+    assert cache_embeddings.query_embedding_store is None
+
+
+def test_embed_cached_query(cache_embeddings_with_query: CacheBackedEmbeddings) -> None:
+    text = "query_text"
+    vector = cache_embeddings_with_query.embed_query(text)
+    expected_vector = [5.0, 6.0]
+    assert vector == expected_vector
+    keys = list(cache_embeddings_with_query.query_embedding_store.yield_keys())  # type: ignore[union-attr]
+    assert len(keys) == 1
+    assert keys[0] == "test_namespace89ec3dae-a4d9-5636-a62e-ff3b56cdfa15"


 async def test_aembed_documents(cache_embeddings: CacheBackedEmbeddings) -> None:
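
The literal key asserted above follows from the module's key encoder: the namespace is prepended to a UUID5 derived from the SHA-1 digest of the text. A sketch of the derivation, assuming the _hash_string_to_uuid helper and NAMESPACE_UUID constant in cache.py are unchanged by this commit:

import hashlib
import uuid

NAMESPACE_UUID = uuid.UUID(int=1985)  # module-level constant in cache.py


def expected_key(namespace: str, text: str) -> str:
    """Reproduce the cache key: namespace + uuid5(sha1(text))."""
    digest = hashlib.sha1(text.encode("utf-8")).hexdigest()
    return namespace + str(uuid.uuid5(NAMESPACE_UUID, digest))


print(expected_key("test_namespace", "query_text"))
# expected: test_namespace89ec3dae-a4d9-5636-a62e-ff3b56cdfa15
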
@@ -112,3 +137,13 @@ async def test_aembed_query(cache_embeddings: CacheBackedEmbeddings) -> None:
     vector = await cache_embeddings.aembed_query(text)
     expected_vector = [5.0, 6.0]
     assert vector == expected_vector
+
+
+async def test_aembed_query_cached(
+    cache_embeddings_with_query: CacheBackedEmbeddings,
+) -> None:
+    text = "query_text"
+    await cache_embeddings_with_query.aembed_query(text)
+    keys = list(cache_embeddings_with_query.query_embedding_store.yield_keys())  # type: ignore[union-attr]
+    assert len(keys) == 1
+    assert keys[0] == "test_namespace89ec3dae-a4d9-5636-a62e-ff3b56cdfa15"