langchain[minor]: allow CacheBackedEmbeddings to cache queries (#20073)

Add optional caching of queries to cache-backed embeddings.

parent a156aace2b
commit b53548dcda
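For orientation, here is a minimal usage sketch of what this commit enables. The model and store choices are placeholders (any `Embeddings` implementation and any `ByteStore` should work); only the `query_embedding_cache` argument is new:

```python
from langchain.embeddings import CacheBackedEmbeddings, OpenAIEmbeddings
from langchain.storage import InMemoryByteStore

store = InMemoryByteStore()
embedder = CacheBackedEmbeddings.from_bytes_store(
    OpenAIEmbeddings(),  # placeholder: any Embeddings implementation
    store,  # backs the document-embedding cache
    namespace="text-embedding-ada-002",  # avoid cross-model collisions
    query_embedding_cache=True,  # new: reuse the same store for queries
)

embedder.embed_query("What is LangChain?")  # computed and cached
embedder.embed_query("What is LangChain?")  # served from the cache
```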
@@ -18,11 +18,12 @@
     "- document_embedding_cache: Any [`ByteStore`](/docs/integrations/stores/) for caching document embeddings.\n",
     "- batch_size: (optional, defaults to `None`) The number of documents to embed between store updates.\n",
     "- namespace: (optional, defaults to `\"\"`) The namespace to use for the document cache. This namespace is used to avoid collisions with other caches. For example, set it to the name of the embedding model used.\n",
+    "- query_embedding_cache: (optional, defaults to `None`, i.e. no query caching) A [`ByteStore`](/docs/integrations/stores/) for caching query embeddings, or `True` to use the same store as `document_embedding_cache`.\n",
     "\n",
     "**Attention**:\n",
     "\n",
     "- Be sure to set the `namespace` parameter to avoid collisions of the same text embedded using different embedding models.\n",
-    "- Currently `CacheBackedEmbeddings` does not cache embeddings created with the `embed_query()` and `aembed_query()` methods."
+    "- `CacheBackedEmbeddings` does not cache query embeddings by default. To enable query caching, one needs to specify a `query_embedding_cache`."
    ]
   },
   {
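As documented in the bullet above, `query_embedding_cache` also accepts its own `ByteStore`, so document and query caches can live in different backends. A sketch under assumed store choices (persistent documents, throwaway queries):

```python
from langchain.embeddings import CacheBackedEmbeddings, OpenAIEmbeddings
from langchain.storage import InMemoryByteStore, LocalFileStore

embedder = CacheBackedEmbeddings.from_bytes_store(
    OpenAIEmbeddings(),  # placeholder model
    LocalFileStore("./doc_embedding_cache"),  # documents: keep on disk
    namespace="text-embedding-ada-002",
    query_embedding_cache=InMemoryByteStore(),  # queries: fine to lose
)
```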
@@ -6,6 +6,7 @@ embeddings for the same text.
 
 The text is hashed and the hash is used as the key in the cache.
 """
 
 from __future__ import annotations
 
 import hashlib
@@ -59,6 +60,9 @@ class CacheBackedEmbeddings(Embeddings):
     If need be, the interface can be extended to accept other implementations
     of the value serializer and deserializer, as well as the key encoder.
 
+    Note that, by default, only document embeddings are cached. To cache query
+    embeddings too, pass a `query_embedding_store` to the constructor.
+
     Examples:
 
         .. code-block:: python
@@ -87,6 +91,7 @@ class CacheBackedEmbeddings(Embeddings):
         document_embedding_store: BaseStore[str, List[float]],
         *,
         batch_size: Optional[int] = None,
+        query_embedding_store: Optional[BaseStore[str, List[float]]] = None,
     ) -> None:
         """Initialize the embedder.
@@ -94,9 +99,12 @@ class CacheBackedEmbeddings(Embeddings):
             underlying_embeddings: the embedder to use for computing embeddings.
             document_embedding_store: The store to use for caching document embeddings.
             batch_size: The number of documents to embed between store updates.
+            query_embedding_store: The store to use for caching query embeddings.
+                If None, query embeddings are not cached.
         """
         super().__init__()
         self.document_embedding_store = document_embedding_store
+        self.query_embedding_store = query_embedding_store
         self.underlying_embeddings = underlying_embeddings
         self.batch_size = batch_size
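Unlike `from_bytes_store` below, the constructor shown above expects stores that hold decoded `List[float]` values directly (no key encoding or value serialization). A direct-construction sketch using `InMemoryStore` and a toy `Embeddings` subclass as stand-ins:

```python
from typing import List

from langchain.embeddings import CacheBackedEmbeddings
from langchain.storage import InMemoryStore
from langchain_core.embeddings import Embeddings


class ToyEmbeddings(Embeddings):
    """Illustration only: embeds text as [len(text), 0.0]."""

    def embed_documents(self, texts: List[str]) -> List[List[float]]:
        return [[float(len(t)), 0.0] for t in texts]

    def embed_query(self, text: str) -> List[float]:
        return [float(len(text)), 0.0]


embedder = CacheBackedEmbeddings(
    ToyEmbeddings(),
    InMemoryStore(),  # document_embedding_store
    batch_size=10,
    query_embedding_store=InMemoryStore(),  # omit (None) to skip query caching
)
```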
@@ -173,14 +181,8 @@ class CacheBackedEmbeddings(Embeddings):
     def embed_query(self, text: str) -> List[float]:
         """Embed query text.
 
-        This method does not support caching at the moment.
-
-        Support for caching queries is easy to implement, but it might make
-        sense to hold off to see the most common patterns.
-
-        If the cache has an eviction policy, we may need to be a bit more careful
-        about sharing the cache between documents and queries. Generally,
-        one is OK evicting query caches, but document caches should be kept.
+        By default, this method does not cache queries. To enable caching, set
+        the `query_embedding_cache` parameter when initializing the embedder.
 
         Args:
             text: The text to embed.
@@ -188,19 +190,22 @@ class CacheBackedEmbeddings(Embeddings):
         Returns:
             The embedding for the given text.
         """
-        return self.underlying_embeddings.embed_query(text)
+        if not self.query_embedding_store:
+            return self.underlying_embeddings.embed_query(text)
+
+        (cached,) = self.query_embedding_store.mget([text])
+        if cached is not None:
+            return cached
+
+        vector = self.underlying_embeddings.embed_query(text)
+        self.query_embedding_store.mset([(text, vector)])
+        return vector
 
     async def aembed_query(self, text: str) -> List[float]:
         """Embed query text.
 
-        This method does not support caching at the moment.
-
-        Support for caching queries is easy to implement, but it might make
-        sense to hold off to see the most common patterns.
-
-        If the cache has an eviction policy, we may need to be a bit more careful
-        about sharing the cache between documents and queries. Generally,
-        one is OK evicting query caches, but document caches should be kept.
+        By default, this method does not cache queries. To enable caching, set
+        the `query_embedding_cache` parameter when initializing the embedder.
 
         Args:
             text: The text to embed.
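The read-through logic added above (try `mget`, fall back to the model, write back with `mset`) means a repeated query invokes the underlying model only once. A small sketch with a hypothetical call-counting model, to make the cache hit observable:

```python
from typing import List

from langchain.embeddings import CacheBackedEmbeddings
from langchain.storage import InMemoryByteStore
from langchain_core.embeddings import Embeddings


class CountingEmbeddings(Embeddings):
    """Counts invocations so cache hits are observable."""

    def __init__(self) -> None:
        self.query_calls = 0

    def embed_documents(self, texts: List[str]) -> List[List[float]]:
        return [[1.0, 2.0] for _ in texts]

    def embed_query(self, text: str) -> List[float]:
        self.query_calls += 1
        return [1.0, 2.0]


model = CountingEmbeddings()
embedder = CacheBackedEmbeddings.from_bytes_store(
    model, InMemoryByteStore(), namespace="demo", query_embedding_cache=True
)
embedder.embed_query("hello")
embedder.embed_query("hello")
assert model.query_calls == 1  # second call was a cache hit
```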
@@ -208,8 +213,17 @@ class CacheBackedEmbeddings(Embeddings):
         Returns:
             The embedding for the given text.
         """
-        return await self.underlying_embeddings.aembed_query(text)
+        if not self.query_embedding_store:
+            return await self.underlying_embeddings.aembed_query(text)
+
+        (cached,) = await self.query_embedding_store.amget([text])
+        if cached is not None:
+            return cached
+
+        vector = await self.underlying_embeddings.aembed_query(text)
+        await self.query_embedding_store.amset([(text, vector)])
+        return vector
 
     @classmethod
     def from_bytes_store(
         cls,
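The async variant above mirrors the sync path via `amget`/`amset`. A usage sketch, reusing the `CountingEmbeddings` stand-in from the previous sketch; it assumes the base `Embeddings` class provides default async wrappers that delegate to the sync methods (true for recent langchain-core versions):

```python
import asyncio

from langchain.embeddings import CacheBackedEmbeddings
from langchain.storage import InMemoryByteStore


async def main() -> None:
    model = CountingEmbeddings()  # stand-in from the previous sketch
    embedder = CacheBackedEmbeddings.from_bytes_store(
        model, InMemoryByteStore(), namespace="demo", query_embedding_cache=True
    )
    await embedder.aembed_query("hello")
    await embedder.aembed_query("hello")  # second call: cache hit via amget
    assert model.query_calls == 1


asyncio.run(main())
```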
@@ -218,6 +232,7 @@ class CacheBackedEmbeddings(Embeddings):
         *,
         namespace: str = "",
         batch_size: Optional[int] = None,
+        query_embedding_cache: Union[bool, ByteStore] = False,
     ) -> CacheBackedEmbeddings:
         """On-ramp that adds the necessary serialization and encoding to the store.
@@ -229,13 +244,33 @@ class CacheBackedEmbeddings(Embeddings):
                 This namespace is used to avoid collisions with other caches.
                 For example, set it to the name of the embedding model used.
             batch_size: The number of documents to embed between store updates.
+            query_embedding_cache: The cache to use for storing query embeddings.
+                True to use the same cache as document embeddings.
+                False to not cache query embeddings.
         """
         namespace = namespace
         key_encoder = _create_key_encoder(namespace)
-        encoder_backed_store = EncoderBackedStore[str, List[float]](
+        document_embedding_store = EncoderBackedStore[str, List[float]](
            document_embedding_cache,
            key_encoder,
            _value_serializer,
            _value_deserializer,
        )
-        return cls(underlying_embeddings, encoder_backed_store, batch_size=batch_size)
+        if query_embedding_cache is True:
+            query_embedding_store = document_embedding_store
+        elif query_embedding_cache is False:
+            query_embedding_store = None
+        else:
+            query_embedding_store = EncoderBackedStore[str, List[float]](
+                query_embedding_cache,
+                key_encoder,
+                _value_serializer,
+                _value_deserializer,
+            )
+
+        return cls(
+            underlying_embeddings,
+            document_embedding_store,
+            batch_size=batch_size,
+            query_embedding_store=query_embedding_store,
+        )
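To summarize the dispatch just added: `query_embedding_cache` has three shapes. Assuming `model`, `doc_store`, and `query_store` are an `Embeddings` instance and two `ByteStore`s as in the earlier sketches:

```python
# Default: no query caching (backward compatible).
CacheBackedEmbeddings.from_bytes_store(model, doc_store)

# True: share one store between document and query embeddings.
CacheBackedEmbeddings.from_bytes_store(model, doc_store, query_embedding_cache=True)

# A ByteStore: keep query embeddings separate, e.g. behind an eviction policy.
CacheBackedEmbeddings.from_bytes_store(
    model, doc_store, query_embedding_cache=query_store
)
```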
@@ -43,6 +43,20 @@ def cache_embeddings_batch() -> CacheBackedEmbeddings:
     )
 
 
+@pytest.fixture
+def cache_embeddings_with_query() -> CacheBackedEmbeddings:
+    """Create cache-backed embeddings with query caching enabled."""
+    doc_store = InMemoryStore()
+    query_store = InMemoryStore()
+    embeddings = MockEmbeddings()
+    return CacheBackedEmbeddings.from_bytes_store(
+        embeddings,
+        document_embedding_cache=doc_store,
+        namespace="test_namespace",
+        query_embedding_cache=query_store,
+    )
+
+
 def test_embed_documents(cache_embeddings: CacheBackedEmbeddings) -> None:
     texts = ["1", "22", "a", "333"]
     vectors = cache_embeddings.embed_documents(texts)
@@ -73,6 +87,17 @@ def test_embed_query(cache_embeddings: CacheBackedEmbeddings) -> None:
     vector = cache_embeddings.embed_query(text)
     expected_vector = [5.0, 6.0]
     assert vector == expected_vector
+    assert cache_embeddings.query_embedding_store is None
+
+
+def test_embed_cached_query(cache_embeddings_with_query: CacheBackedEmbeddings) -> None:
+    text = "query_text"
+    vector = cache_embeddings_with_query.embed_query(text)
+    expected_vector = [5.0, 6.0]
+    assert vector == expected_vector
+    keys = list(cache_embeddings_with_query.query_embedding_store.yield_keys())  # type: ignore[union-attr]
+    assert len(keys) == 1
+    assert keys[0] == "test_namespace89ec3dae-a4d9-5636-a62e-ff3b56cdfa15"
 
 
 async def test_aembed_documents(cache_embeddings: CacheBackedEmbeddings) -> None:
@@ -112,3 +137,13 @@ async def test_aembed_query(cache_embeddings: CacheBackedEmbeddings) -> None:
     vector = await cache_embeddings.aembed_query(text)
     expected_vector = [5.0, 6.0]
     assert vector == expected_vector
+
+
+async def test_aembed_query_cached(
+    cache_embeddings_with_query: CacheBackedEmbeddings,
+) -> None:
+    text = "query_text"
+    await cache_embeddings_with_query.aembed_query(text)
+    keys = list(cache_embeddings_with_query.query_embedding_store.yield_keys())  # type: ignore[union-attr]
+    assert len(keys) == 1
+    assert keys[0] == "test_namespace89ec3dae-a4d9-5636-a62e-ff3b56cdfa15"
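The exact key asserted in these tests is deterministic: the namespace concatenated with a UUID derived from a hash of the text, as in the sketch below. The specific scheme (a SHA-1 digest folded into a UUIDv5 under a fixed namespace UUID) is an assumption about the library's key encoder, not something shown in this diff:

```python
import hashlib
import uuid
from typing import Callable

NAMESPACE_UUID = uuid.UUID(int=1985)  # assumed constant for uuid5 derivation


def _hash_string_to_uuid(input_string: str) -> uuid.UUID:
    # Hash the text, then fold the digest into a stable UUID.
    hash_value = hashlib.sha1(input_string.encode("utf-8")).hexdigest()
    return uuid.uuid5(NAMESPACE_UUID, hash_value)


def _create_key_encoder(namespace: str) -> Callable[[str], str]:
    # Same namespace + same text -> same key, which is why the tests
    # can assert the key verbatim.
    return lambda key: namespace + str(_hash_string_to_uuid(key))


print(_create_key_encoder("test_namespace")("query_text"))
```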