langchain[minor]: allow CacheBackedEmbeddings to cache queries (#20073)
Add optional caching of queries to cache backed embeddings
This commit is contained in: parent a156aace2b, commit b53548dcda
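For orientation, a minimal usage sketch of what this commit enables. `TinyEmbeddings` is a toy stand-in for a real model, and `InMemoryByteStore` from `langchain.storage` is assumed available; only the `query_embedding_cache` argument is new here:

from typing import List

from langchain.embeddings import CacheBackedEmbeddings
from langchain.storage import InMemoryByteStore
from langchain_core.embeddings import Embeddings


class TinyEmbeddings(Embeddings):
    """Toy deterministic embedder standing in for a real model."""

    def embed_documents(self, texts: List[str]) -> List[List[float]]:
        return [[float(len(t)), 0.0] for t in texts]

    def embed_query(self, text: str) -> List[float]:
        return [float(len(text)), 1.0]


embedder = CacheBackedEmbeddings.from_bytes_store(
    TinyEmbeddings(),
    document_embedding_cache=InMemoryByteStore(),
    namespace="tiny-embeddings",  # avoid collisions between models
    query_embedding_cache=True,  # new: reuse the document cache for queries
)

embedder.embed_query("hello")  # computed and written to the cache
embedder.embed_query("hello")  # served from the cache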
@@ -18,11 +18,12 @@
     "- document_embedding_cache: Any [`ByteStore`](/docs/integrations/stores/) for caching document embeddings.\n",
     "- batch_size: (optional, defaults to `None`) The number of documents to embed between store updates.\n",
     "- namespace: (optional, defaults to `\"\"`) The namespace to use for document cache. This namespace is used to avoid collisions with other caches. For example, set it to the name of the embedding model used.\n",
+    "- query_embedding_cache: (optional, defaults to `False`, i.e. queries are not cached) A [`ByteStore`](/docs/integrations/stores/) for caching query embeddings, or `True` to use the same store as `document_embedding_cache`.\n",
     "\n",
     "**Attention**:\n",
     "\n",
     "- Be sure to set the `namespace` parameter to avoid collisions of the same text embedded using different embeddings models.\n",
-    "- Currently `CacheBackedEmbeddings` does not cache embedding created with `embed_query()` `aembed_query()` methods."
+    "- `CacheBackedEmbeddings` does not cache query embeddings by default. To enable query caching, one needs to specify a `query_embedding_cache`."
    ]
   },
   {
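The notebook text above distinguishes reusing the document store (`True`) from supplying a dedicated query store. A sketch of the dedicated-store variant, under the same assumptions as the sketch above (toy embedder, in-memory byte stores):

from typing import List

from langchain.embeddings import CacheBackedEmbeddings
from langchain.storage import InMemoryByteStore
from langchain_core.embeddings import Embeddings


class TinyEmbeddings(Embeddings):
    def embed_documents(self, texts: List[str]) -> List[List[float]]:
        return [[float(len(t))] for t in texts]

    def embed_query(self, text: str) -> List[float]:
        return [float(len(text))]


doc_cache = InMemoryByteStore()
query_cache = InMemoryByteStore()  # separate store, safe to evict on its own

embedder = CacheBackedEmbeddings.from_bytes_store(
    TinyEmbeddings(),
    document_embedding_cache=doc_cache,
    namespace="tiny-embeddings",
    query_embedding_cache=query_cache,
)

A dedicated store matters when the backing cache has an eviction policy: query entries are cheap to recompute, while document entries are usually worth keeping, which is the rationale the docstring removed later in this diff spells out.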
@@ -6,6 +6,7 @@ embeddings for the same text.

 The text is hashed and the hash is used as the key in the cache.
 """
+
 from __future__ import annotations

 import hashlib
@@ -59,6 +60,9 @@ class CacheBackedEmbeddings(Embeddings):
     If need be, the interface can be extended to accept other implementations
     of the value serializer and deserializer, as well as the key encoder.

+    Note that by default only document embeddings are cached. To cache query
+    embeddings too, pass a `query_embedding_store` to the constructor.
+
     Examples:

         .. code-block: python
@@ -87,6 +91,7 @@ class CacheBackedEmbeddings(Embeddings):
         document_embedding_store: BaseStore[str, List[float]],
         *,
         batch_size: Optional[int] = None,
+        query_embedding_store: Optional[BaseStore[str, List[float]]] = None,
     ) -> None:
         """Initialize the embedder.

@@ -94,9 +99,12 @@ class CacheBackedEmbeddings(Embeddings):
             underlying_embeddings: the embedder to use for computing embeddings.
             document_embedding_store: The store to use for caching document embeddings.
             batch_size: The number of documents to embed between store updates.
+            query_embedding_store: The store to use for caching query embeddings.
+                If None, query embeddings are not cached.
         """
         super().__init__()
         self.document_embedding_store = document_embedding_store
+        self.query_embedding_store = query_embedding_store
         self.underlying_embeddings = underlying_embeddings
         self.batch_size = batch_size

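The constructor path above takes fully wired stores rather than raw byte stores. A sketch of direct construction, assuming `InMemoryStore` (which holds `List[float]` values as-is) is an acceptable `BaseStore[str, List[float]]`:

from typing import List

from langchain.embeddings import CacheBackedEmbeddings
from langchain.storage import InMemoryStore
from langchain_core.embeddings import Embeddings


class TinyEmbeddings(Embeddings):
    def embed_documents(self, texts: List[str]) -> List[List[float]]:
        return [[float(len(t))] for t in texts]

    def embed_query(self, text: str) -> List[float]:
        return [float(len(text))]


embedder = CacheBackedEmbeddings(
    TinyEmbeddings(),
    document_embedding_store=InMemoryStore(),
    query_embedding_store=InMemoryStore(),  # the new keyword-only parameter
)

Note that, unlike `from_bytes_store`, nothing namespaces or hashes the keys here: the raw text is used as the store key.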
@@ -173,14 +181,8 @@ class CacheBackedEmbeddings(Embeddings):
     def embed_query(self, text: str) -> List[float]:
         """Embed query text.

-        This method does not support caching at the moment.
-        Support for caching queries is easily to implement, but might make
-        sense to hold off to see the most common patterns.
-
-        If the cache has an eviction policy, we may need to be a bit more careful
-        about sharing the cache between documents and queries. Generally,
-        one is OK evicting query caches, but document caches should be kept.
+        By default, this method does not cache queries. To enable caching, pass a
+        `query_embedding_cache` to `from_bytes_store` when initializing the embedder.

         Args:
             text: The text to embed.
@@ -188,19 +190,22 @@ class CacheBackedEmbeddings(Embeddings):
         Returns:
             The embedding for the given text.
         """
-        return self.underlying_embeddings.embed_query(text)
+        if not self.query_embedding_store:
+            return self.underlying_embeddings.embed_query(text)
+
+        (cached,) = self.query_embedding_store.mget([text])
+        if cached is not None:
+            return cached
+
+        vector = self.underlying_embeddings.embed_query(text)
+        self.query_embedding_store.mset([(text, vector)])
+        return vector

     async def aembed_query(self, text: str) -> List[float]:
         """Embed query text.

-        This method does not support caching at the moment.
-        Support for caching queries is easily to implement, but might make
-        sense to hold off to see the most common patterns.
-
-        If the cache has an eviction policy, we may need to be a bit more careful
-        about sharing the cache between documents and queries. Generally,
-        one is OK evicting query caches, but document caches should be kept.
+        By default, this method does not cache queries. To enable caching, pass a
+        `query_embedding_cache` to `from_bytes_store` when initializing the embedder.

         Args:
             text: The text to embed.
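To make the new control flow concrete, a small self-check under the same assumptions as the earlier sketches; `CountingEmbeddings` is a hypothetical stand-in that counts calls to the underlying model:

from typing import List

from langchain.embeddings import CacheBackedEmbeddings
from langchain.storage import InMemoryByteStore
from langchain_core.embeddings import Embeddings


class CountingEmbeddings(Embeddings):
    def __init__(self) -> None:
        self.calls = 0

    def embed_documents(self, texts: List[str]) -> List[List[float]]:
        return [[0.0] for _ in texts]

    def embed_query(self, text: str) -> List[float]:
        self.calls += 1
        return [float(len(text))]


underlying = CountingEmbeddings()
embedder = CacheBackedEmbeddings.from_bytes_store(
    underlying,
    document_embedding_cache=InMemoryByteStore(),
    namespace="demo",
    query_embedding_cache=True,
)

embedder.embed_query("hello")  # miss: computes, then mset() writes the vector
embedder.embed_query("hello")  # hit: mget() returns the cached vector
assert underlying.calls == 1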
@@ -208,7 +213,16 @@ class CacheBackedEmbeddings(Embeddings):
         Returns:
             The embedding for the given text.
         """
-        return await self.underlying_embeddings.aembed_query(text)
+        if not self.query_embedding_store:
+            return await self.underlying_embeddings.aembed_query(text)
+
+        (cached,) = await self.query_embedding_store.amget([text])
+        if cached is not None:
+            return cached
+
+        vector = await self.underlying_embeddings.aembed_query(text)
+        await self.query_embedding_store.amset([(text, vector)])
+        return vector

     @classmethod
     def from_bytes_store(
@@ -218,6 +232,7 @@ class CacheBackedEmbeddings(Embeddings):
         *,
         namespace: str = "",
         batch_size: Optional[int] = None,
+        query_embedding_cache: Union[bool, ByteStore] = False,
     ) -> CacheBackedEmbeddings:
         """On-ramp that adds the necessary serialization and encoding to the store.

@@ -229,13 +244,33 @@ class CacheBackedEmbeddings(Embeddings):
                 This namespace is used to avoid collisions with other caches.
                 For example, set it to the name of the embedding model used.
             batch_size: The number of documents to embed between store updates.
+            query_embedding_cache: The cache to use for storing query embeddings.
+                True to use the same cache as document embeddings.
+                False to not cache query embeddings.
         """
         namespace = namespace
         key_encoder = _create_key_encoder(namespace)
-        encoder_backed_store = EncoderBackedStore[str, List[float]](
+        document_embedding_store = EncoderBackedStore[str, List[float]](
             document_embedding_cache,
             key_encoder,
             _value_serializer,
             _value_deserializer,
         )
-        return cls(underlying_embeddings, encoder_backed_store, batch_size=batch_size)
+        if query_embedding_cache is True:
+            query_embedding_store = document_embedding_store
+        elif query_embedding_cache is False:
+            query_embedding_store = None
+        else:
+            query_embedding_store = EncoderBackedStore[str, List[float]](
+                query_embedding_cache,
+                key_encoder,
+                _value_serializer,
+                _value_deserializer,
+            )
+
+        return cls(
+            underlying_embeddings,
+            document_embedding_store,
+            batch_size=batch_size,
+            query_embedding_store=query_embedding_store,
+        )
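The dispatch above accepts three kinds of values for `query_embedding_cache`. A sketch of all three, again with a toy embedder and in-memory byte stores standing in for real components:

from typing import List

from langchain.embeddings import CacheBackedEmbeddings
from langchain.storage import InMemoryByteStore
from langchain_core.embeddings import Embeddings


class TinyEmbeddings(Embeddings):
    def embed_documents(self, texts: List[str]) -> List[List[float]]:
        return [[float(len(t))] for t in texts]

    def embed_query(self, text: str) -> List[float]:
        return [float(len(text))]


doc_cache = InMemoryByteStore()

# False (the default): embed_query()/aembed_query() bypass the cache entirely.
no_query_cache = CacheBackedEmbeddings.from_bytes_store(
    TinyEmbeddings(), document_embedding_cache=doc_cache, namespace="tiny"
)

# True: queries share the document store (same key encoder and namespace).
shared = CacheBackedEmbeddings.from_bytes_store(
    TinyEmbeddings(),
    document_embedding_cache=doc_cache,
    namespace="tiny",
    query_embedding_cache=True,
)

# A ByteStore: queries get their own encoder-backed store.
dedicated = CacheBackedEmbeddings.from_bytes_store(
    TinyEmbeddings(),
    document_embedding_cache=doc_cache,
    namespace="tiny",
    query_embedding_cache=InMemoryByteStore(),
)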
@@ -43,6 +43,20 @@ def cache_embeddings_batch() -> CacheBackedEmbeddings:
     )


+@pytest.fixture
+def cache_embeddings_with_query() -> CacheBackedEmbeddings:
+    """Create a cache backed embeddings with query caching."""
+    doc_store = InMemoryStore()
+    query_store = InMemoryStore()
+    embeddings = MockEmbeddings()
+    return CacheBackedEmbeddings.from_bytes_store(
+        embeddings,
+        document_embedding_cache=doc_store,
+        namespace="test_namespace",
+        query_embedding_cache=query_store,
+    )
+
+
 def test_embed_documents(cache_embeddings: CacheBackedEmbeddings) -> None:
     texts = ["1", "22", "a", "333"]
     vectors = cache_embeddings.embed_documents(texts)
@@ -73,6 +87,17 @@ def test_embed_query(cache_embeddings: CacheBackedEmbeddings) -> None:
     vector = cache_embeddings.embed_query(text)
     expected_vector = [5.0, 6.0]
     assert vector == expected_vector
+    assert cache_embeddings.query_embedding_store is None
+
+
+def test_embed_cached_query(cache_embeddings_with_query: CacheBackedEmbeddings) -> None:
+    text = "query_text"
+    vector = cache_embeddings_with_query.embed_query(text)
+    expected_vector = [5.0, 6.0]
+    assert vector == expected_vector
+    keys = list(cache_embeddings_with_query.query_embedding_store.yield_keys())  # type: ignore[union-attr]
+    assert len(keys) == 1
+    assert keys[0] == "test_namespace89ec3dae-a4d9-5636-a62e-ff3b56cdfa15"


 async def test_aembed_documents(cache_embeddings: CacheBackedEmbeddings) -> None:
@@ -112,3 +137,13 @@ async def test_aembed_query(cache_embeddings: CacheBackedEmbeddings) -> None:
     vector = await cache_embeddings.aembed_query(text)
     expected_vector = [5.0, 6.0]
     assert vector == expected_vector
+
+
+async def test_aembed_query_cached(
+    cache_embeddings_with_query: CacheBackedEmbeddings,
+) -> None:
+    text = "query_text"
+    await cache_embeddings_with_query.aembed_query(text)
+    keys = list(cache_embeddings_with_query.query_embedding_store.yield_keys())  # type: ignore[union-attr]
+    assert len(keys) == 1
+    assert keys[0] == "test_namespace89ec3dae-a4d9-5636-a62e-ff3b56cdfa15"
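The expected keys in these tests ("test_namespace" followed by a UUID) come from the key encoder that `from_bytes_store` wires in. The helpers involved (`_create_key_encoder`, `NAMESPACE_UUID`) are not shown in this diff, so the reconstruction below is an assumption about cache.py internals: the text is SHA-1 hashed and the digest mapped to a UUID5 under a fixed namespace UUID.

import hashlib
import uuid

# Assumed constant from cache.py; not visible in this diff.
NAMESPACE_UUID = uuid.UUID(int=1985)


def key_encoder(namespace: str, text: str) -> str:
    # Assumed scheme: sha1(text) -> uuid5 -> prefixed with the namespace.
    digest = hashlib.sha1(text.encode("utf-8")).hexdigest()
    return namespace + str(uuid.uuid5(NAMESPACE_UUID, digest))


# If the assumptions hold, this reproduces the key asserted above:
# key_encoder("test_namespace", "query_text")
#   == "test_namespace89ec3dae-a4d9-5636-a62e-ff3b56cdfa15"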