core: implement a batch_size parameter for CacheBackedEmbeddings (#18070)

**Description:**

Currently, `CacheBackedEmbeddings` computes vectors for *all* uncached
documents before updating the store. This pull request updates the
embedding computation loop to compute embeddings in batches, updating
the store after each batch.

I noticed this when I ran `CacheBackedEmbeddings` over our 30k-document
set and the cache directory still hadn't appeared on disk after 30 minutes.

The motivation is to minimize compute/data loss when problems occur (see the sketch after this list):

* If there is a transient embedding failure (e.g. a network outage at
the embedding endpoint triggers an exception), at least the completed
vectors are written to the store instead of being discarded.
* If there is an issue with the store (e.g. no write permissions), the
condition is detected early without computing (and discarding!) all the
vectors.
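
To illustrate the intended workflow, here is a minimal sketch (the `OpenAIEmbeddings` model, the `LocalFileStore` path, and the retry policy are placeholder choices, not part of this change). Because each completed batch is persisted immediately, a retry after a transient failure re-embeds only the documents whose vectors are still missing from the cache:

```python
from langchain.embeddings import CacheBackedEmbeddings
from langchain.storage import LocalFileStore
from langchain_openai import OpenAIEmbeddings

# Placeholder embedder and on-disk store; any Embeddings + ByteStore pair works.
store = LocalFileStore("./embedding_cache/")
embedder = CacheBackedEmbeddings.from_bytes_store(
    OpenAIEmbeddings(),
    store,
    namespace="text-embedding-ada-002",  # one namespace per model
    batch_size=100,  # write vectors to the store after every 100 documents
)

texts = [f"document {i}" for i in range(30_000)]  # stand-in corpus

for attempt in range(3):
    try:
        vectors = embedder.embed_documents(texts)
        break
    except Exception:
        # Completed batches are already in the store; the retry re-embeds
        # only the documents whose vectors are still missing.
        if attempt == 2:
            raise
```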

**Issue:**
Implements enhancement #18026.

**Testing:**
I was unable to run unit tests; details in [this
post](https://github.com/langchain-ai/langchain/discussions/15019#discussioncomment-8576684).

---------

Signed-off-by: chrispy <chrispy@synopsys.com>
Co-authored-by: Eugene Yurtsev <eyurtsev@gmail.com>
Authored by Chris Papademetrious on 2024-03-19 14:55:43 -04:00; committed by GitHub.
Commit 305d74c67a (parent 89af30807b)
4 changed files with 74 additions and 11 deletions

**File: docs (caching embeddings notebook)**

```diff
@@ -22,10 +22,11 @@
 "Caching embeddings can be done using a `CacheBackedEmbeddings`. The cache backed embedder is a wrapper around an embedder that caches\n",
 "embeddings in a key-value store. The text is hashed and the hash is used as the key in the cache.\n",
 "\n",
-"The main supported way to initialized a `CacheBackedEmbeddings` is `from_bytes_store`. This takes in the following parameters:\n",
+"The main supported way to initialize a `CacheBackedEmbeddings` is `from_bytes_store`. It takes the following parameters:\n",
 "\n",
 "- underlying_embedder: The embedder to use for embedding.\n",
 "- document_embedding_cache: Any [`ByteStore`](/docs/integrations/stores/) for caching document embeddings.\n",
+"- batch_size: (optional, defaults to `None`) The number of documents to embed between store updates.\n",
 "- namespace: (optional, defaults to `\"\"`) The namespace to use for document cache. This namespace is used to avoid collisions with other caches. For example, set it to the name of the embedding model used.\n",
 "\n",
 "**Attention**: Be sure to set the `namespace` parameter to avoid collisions of the same text embedded using different embeddings models."
```

**File: `langchain_core/utils/iter.py`**

```diff
@@ -165,8 +165,16 @@ class Tee(Generic[T]):
 safetee = Tee
 
 
-def batch_iterate(size: int, iterable: Iterable[T]) -> Iterator[List[T]]:
-    """Utility batching function."""
+def batch_iterate(size: Optional[int], iterable: Iterable[T]) -> Iterator[List[T]]:
+    """Utility batching function.
+
+    Args:
+        size: The size of the batch. If None, returns a single batch.
+        iterable: The iterable to batch.
+
+    Returns:
+        An iterator over the batches.
+    """
     it = iter(iterable)
     while True:
         chunk = list(islice(it, size))
```
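
The hunk is truncated mid-loop; as a reference, here is a self-contained sketch of the complete utility under the documented contract (the lines past `islice` are a plausible reconstruction, not shown in the diff):

```python
from itertools import islice
from typing import Iterable, Iterator, List, Optional, TypeVar

T = TypeVar("T")


def batch_iterate(size: Optional[int], iterable: Iterable[T]) -> Iterator[List[T]]:
    """Yield successive size-element lists; size=None yields everything at once."""
    it = iter(iterable)
    while True:
        chunk = list(islice(it, size))  # islice(it, None) drains the iterator
        if not chunk:
            return
        yield chunk


assert list(batch_iterate(2, [1, 2, 3, 4, 5])) == [[1, 2], [3, 4], [5]]
assert list(batch_iterate(None, [1, 2, 3])) == [[1, 2, 3]]
```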

**File: `CacheBackedEmbeddings` implementation (`langchain`)**

```diff
@@ -12,10 +12,11 @@ import hashlib
 import json
 import uuid
 from functools import partial
-from typing import Callable, List, Sequence, Union, cast
+from typing import Callable, List, Optional, Sequence, Union, cast
 
 from langchain_core.embeddings import Embeddings
 from langchain_core.stores import BaseStore, ByteStore
+from langchain_core.utils.iter import batch_iterate
 
 from langchain.storage.encoder_backed import EncoderBackedStore
 
@@ -84,16 +85,20 @@ class CacheBackedEmbeddings(Embeddings):
         self,
         underlying_embeddings: Embeddings,
         document_embedding_store: BaseStore[str, List[float]],
+        *,
+        batch_size: Optional[int] = None,
     ) -> None:
         """Initialize the embedder.
 
         Args:
             underlying_embeddings: the embedder to use for computing embeddings.
             document_embedding_store: The store to use for caching document embeddings.
+            batch_size: The number of documents to embed between store updates.
         """
         super().__init__()
         self.document_embedding_store = document_embedding_store
         self.underlying_embeddings = underlying_embeddings
+        self.batch_size = batch_size
 
     def embed_documents(self, texts: List[str]) -> List[List[float]]:
         """Embed a list of texts.
@@ -111,12 +116,12 @@ class CacheBackedEmbeddings(Embeddings):
         vectors: List[Union[List[float], None]] = self.document_embedding_store.mget(
             texts
         )
-        missing_indices: List[int] = [
+        all_missing_indices: List[int] = [
             i for i, vector in enumerate(vectors) if vector is None
         ]
 
-        missing_texts = [texts[i] for i in missing_indices]
-        if missing_texts:
+        for missing_indices in batch_iterate(self.batch_size, all_missing_indices):
+            missing_texts = [texts[i] for i in missing_indices]
             missing_vectors = self.underlying_embeddings.embed_documents(missing_texts)
             self.document_embedding_store.mset(
                 list(zip(missing_texts, missing_vectors))
@@ -144,12 +149,14 @@ class CacheBackedEmbeddings(Embeddings):
         vectors: List[
             Union[List[float], None]
         ] = await self.document_embedding_store.amget(texts)
-        missing_indices: List[int] = [
+        all_missing_indices: List[int] = [
             i for i, vector in enumerate(vectors) if vector is None
         ]
 
-        missing_texts = [texts[i] for i in missing_indices]
-        if missing_texts:
+        # batch_iterate supports None batch_size which returns all elements at once
+        # as a single batch.
+        for missing_indices in batch_iterate(self.batch_size, all_missing_indices):
+            missing_texts = [texts[i] for i in missing_indices]
             missing_vectors = await self.underlying_embeddings.aembed_documents(
                 missing_texts
             )
@@ -210,6 +217,7 @@ class CacheBackedEmbeddings(Embeddings):
         document_embedding_cache: ByteStore,
         *,
         namespace: str = "",
+        batch_size: Optional[int] = None,
     ) -> CacheBackedEmbeddings:
         """On-ramp that adds the necessary serialization and encoding to the store.
 
@@ -220,6 +228,7 @@ class CacheBackedEmbeddings(Embeddings):
             namespace: The namespace to use for document cache.
                 This namespace is used to avoid collisions with other caches.
                 For example, set it to the name of the embedding model used.
+            batch_size: The number of documents to embed between store updates.
         """
         namespace = namespace
         key_encoder = _create_key_encoder(namespace)
@@ -229,4 +238,4 @@ class CacheBackedEmbeddings(Embeddings):
             _value_serializer,
             _value_deserializer,
         )
-        return cls(underlying_embeddings, encoder_backed_store)
+        return cls(underlying_embeddings, encoder_backed_store, batch_size=batch_size)
```
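
The per-batch `mset` is also what surfaces store problems early. A sketch with hypothetical `ReadOnlyStore` and `CountingEmbeddings` stand-ins (illustrative only, not part of this change): a broken store now fails after the first batch instead of after the whole corpus.

```python
from typing import Any, List, Sequence, Tuple

from langchain.embeddings import CacheBackedEmbeddings
from langchain.storage import InMemoryStore
from langchain_core.embeddings import Embeddings


class ReadOnlyStore(InMemoryStore):
    """Hypothetical store whose writes fail, e.g. bad directory permissions."""

    def mset(self, key_value_pairs: Sequence[Tuple[str, Any]]) -> None:
        raise PermissionError("store is not writable")


class CountingEmbeddings(Embeddings):
    """Counts how many documents were actually embedded."""

    def __init__(self) -> None:
        self.calls = 0

    def embed_documents(self, texts: List[str]) -> List[List[float]]:
        self.calls += len(texts)
        return [[float(len(text))] for text in texts]

    def embed_query(self, text: str) -> List[float]:
        return self.embed_documents([text])[0]


embedder = CountingEmbeddings()
cached = CacheBackedEmbeddings.from_bytes_store(
    embedder, ReadOnlyStore(), namespace="demo", batch_size=10
)

try:
    cached.embed_documents([f"doc {i}" for i in range(1000)])
except PermissionError:
    pass

# Only the first batch was embedded before the broken store was detected.
assert embedder.calls == 10
```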

**File: `CacheBackedEmbeddings` unit tests**

```diff
@@ -13,6 +13,8 @@ class MockEmbeddings(Embeddings):
         # Simulate embedding documents
         embeddings: List[List[float]] = []
         for text in texts:
+            if text == "RAISE_EXCEPTION":
+                raise ValueError("Simulated embedding failure")
             embeddings.append([len(text), len(text) + 1])
         return embeddings
 
@@ -31,6 +33,16 @@ def cache_embeddings() -> CacheBackedEmbeddings:
     )
 
 
+@pytest.fixture
+def cache_embeddings_batch() -> CacheBackedEmbeddings:
+    """Create a cache backed embeddings with a batch_size of 3."""
+    store = InMemoryStore()
+    embeddings = MockEmbeddings()
+    return CacheBackedEmbeddings.from_bytes_store(
+        embeddings, store, namespace="test_namespace", batch_size=3
+    )
+
+
 def test_embed_documents(cache_embeddings: CacheBackedEmbeddings) -> None:
     texts = ["1", "22", "a", "333"]
     vectors = cache_embeddings.embed_documents(texts)
@@ -42,6 +54,20 @@ def test_embed_documents(cache_embeddings: CacheBackedEmbeddings) -> None:
     assert keys[0] == "test_namespace812b86c1-8ebf-5483-95c6-c95cf2b52d12"
 
 
+def test_embed_documents_batch(cache_embeddings_batch: CacheBackedEmbeddings) -> None:
+    # "RAISE_EXCEPTION" forces a failure in batch 2
+    texts = ["1", "22", "a", "333", "RAISE_EXCEPTION"]
+    try:
+        cache_embeddings_batch.embed_documents(texts)
+    except ValueError:
+        pass
+    keys = list(cache_embeddings_batch.document_embedding_store.yield_keys())
+    # only the first batch of three embeddings should exist
+    assert len(keys) == 3
+    # UUID is expected to be the same for the same text
+    assert keys[0] == "test_namespace812b86c1-8ebf-5483-95c6-c95cf2b52d12"
+
+
 def test_embed_query(cache_embeddings: CacheBackedEmbeddings) -> None:
     text = "query_text"
     vector = cache_embeddings.embed_query(text)
@@ -62,6 +88,25 @@ async def test_aembed_documents(cache_embeddings: CacheBackedEmbeddings) -> None
     assert keys[0] == "test_namespace812b86c1-8ebf-5483-95c6-c95cf2b52d12"
 
 
+async def test_aembed_documents_batch(
+    cache_embeddings_batch: CacheBackedEmbeddings,
+) -> None:
+    # "RAISE_EXCEPTION" forces a failure in batch 2
+    texts = ["1", "22", "a", "333", "RAISE_EXCEPTION"]
+    try:
+        await cache_embeddings_batch.aembed_documents(texts)
+    except ValueError:
+        pass
+    keys = [
+        key
+        async for key in cache_embeddings_batch.document_embedding_store.ayield_keys()
+    ]
+    # only the first batch of three embeddings should exist
+    assert len(keys) == 3
+    # UUID is expected to be the same for the same text
+    assert keys[0] == "test_namespace812b86c1-8ebf-5483-95c6-c95cf2b52d12"
+
+
 async def test_aembed_query(cache_embeddings: CacheBackedEmbeddings) -> None:
     text = "query_text"
     vector = await cache_embeddings.aembed_query(text)
```