mirror of
https://github.com/hwchase17/langchain.git
synced 2025-06-22 23:00:00 +00:00
core: implement a batch_size parameter for CacheBackedEmbeddings (#18070)
**Description:** Currently, `CacheBackedEmbeddings` computes vectors for *all* uncached documents before updating the store. This pull request updates the embedding computation loop to compute embeddings in batches, updating the store after each batch. I noticed this when I tried `CacheBackedEmbeddings` on our 30k document set and the cache directory hadn't appeared on disk after 30 minutes. The motivation is to minimize compute/data loss when problems occur: * If there is a transient embedding failure (e.g. a network outage at the embedding endpoint triggers an exception), at least the completed vectors are written to the store instead of being discarded. * If there is an issue with the store (e.g. no write permissions), the condition is detected early without computing (and discarding!) all the vectors. **Issue:** Implements enhancement #18026. **Testing:** I was unable to run unit tests; details in [this post](https://github.com/langchain-ai/langchain/discussions/15019#discussioncomment-8576684). --------- Signed-off-by: chrispy <chrispy@synopsys.com> Co-authored-by: Eugene Yurtsev <eyurtsev@gmail.com>
This commit is contained in:
parent
89af30807b
commit
305d74c67a
@ -22,10 +22,11 @@
|
||||
"Caching embeddings can be done using a `CacheBackedEmbeddings`. The cache backed embedder is a wrapper around an embedder that caches\n",
|
||||
"embeddings in a key-value store. The text is hashed and the hash is used as the key in the cache.\n",
|
||||
"\n",
|
||||
"The main supported way to initialized a `CacheBackedEmbeddings` is `from_bytes_store`. This takes in the following parameters:\n",
|
||||
"The main supported way to initialize a `CacheBackedEmbeddings` is `from_bytes_store`. It takes the following parameters:\n",
|
||||
"\n",
|
||||
"- underlying_embedder: The embedder to use for embedding.\n",
|
||||
"- document_embedding_cache: Any [`ByteStore`](/docs/integrations/stores/) for caching document embeddings.\n",
|
||||
"- batch_size: (optional, defaults to `None`) The number of documents to embed between store updates.\n",
|
||||
"- namespace: (optional, defaults to `\"\"`) The namespace to use for document cache. This namespace is used to avoid collisions with other caches. For example, set it to the name of the embedding model used.\n",
|
||||
"\n",
|
||||
"**Attention**: Be sure to set the `namespace` parameter to avoid collisions of the same text embedded using different embeddings models."
|
||||
|
@ -165,8 +165,16 @@ class Tee(Generic[T]):
|
||||
safetee = Tee
|
||||
|
||||
|
||||
def batch_iterate(size: int, iterable: Iterable[T]) -> Iterator[List[T]]:
|
||||
"""Utility batching function."""
|
||||
def batch_iterate(size: Optional[int], iterable: Iterable[T]) -> Iterator[List[T]]:
|
||||
"""Utility batching function.
|
||||
|
||||
Args:
|
||||
size: The size of the batch. If None, returns a single batch.
|
||||
iterable: The iterable to batch.
|
||||
|
||||
Returns:
|
||||
An iterator over the batches.
|
||||
"""
|
||||
it = iter(iterable)
|
||||
while True:
|
||||
chunk = list(islice(it, size))
|
||||
|
@ -12,10 +12,11 @@ import hashlib
|
||||
import json
|
||||
import uuid
|
||||
from functools import partial
|
||||
from typing import Callable, List, Sequence, Union, cast
|
||||
from typing import Callable, List, Optional, Sequence, Union, cast
|
||||
|
||||
from langchain_core.embeddings import Embeddings
|
||||
from langchain_core.stores import BaseStore, ByteStore
|
||||
from langchain_core.utils.iter import batch_iterate
|
||||
|
||||
from langchain.storage.encoder_backed import EncoderBackedStore
|
||||
|
||||
@ -84,16 +85,20 @@ class CacheBackedEmbeddings(Embeddings):
|
||||
self,
|
||||
underlying_embeddings: Embeddings,
|
||||
document_embedding_store: BaseStore[str, List[float]],
|
||||
*,
|
||||
batch_size: Optional[int] = None,
|
||||
) -> None:
|
||||
"""Initialize the embedder.
|
||||
|
||||
Args:
|
||||
underlying_embeddings: the embedder to use for computing embeddings.
|
||||
document_embedding_store: The store to use for caching document embeddings.
|
||||
batch_size: The number of documents to embed between store updates.
|
||||
"""
|
||||
super().__init__()
|
||||
self.document_embedding_store = document_embedding_store
|
||||
self.underlying_embeddings = underlying_embeddings
|
||||
self.batch_size = batch_size
|
||||
|
||||
def embed_documents(self, texts: List[str]) -> List[List[float]]:
|
||||
"""Embed a list of texts.
|
||||
@ -111,12 +116,12 @@ class CacheBackedEmbeddings(Embeddings):
|
||||
vectors: List[Union[List[float], None]] = self.document_embedding_store.mget(
|
||||
texts
|
||||
)
|
||||
missing_indices: List[int] = [
|
||||
all_missing_indices: List[int] = [
|
||||
i for i, vector in enumerate(vectors) if vector is None
|
||||
]
|
||||
missing_texts = [texts[i] for i in missing_indices]
|
||||
|
||||
if missing_texts:
|
||||
for missing_indices in batch_iterate(self.batch_size, all_missing_indices):
|
||||
missing_texts = [texts[i] for i in missing_indices]
|
||||
missing_vectors = self.underlying_embeddings.embed_documents(missing_texts)
|
||||
self.document_embedding_store.mset(
|
||||
list(zip(missing_texts, missing_vectors))
|
||||
@ -144,12 +149,14 @@ class CacheBackedEmbeddings(Embeddings):
|
||||
vectors: List[
|
||||
Union[List[float], None]
|
||||
] = await self.document_embedding_store.amget(texts)
|
||||
missing_indices: List[int] = [
|
||||
all_missing_indices: List[int] = [
|
||||
i for i, vector in enumerate(vectors) if vector is None
|
||||
]
|
||||
missing_texts = [texts[i] for i in missing_indices]
|
||||
|
||||
if missing_texts:
|
||||
# batch_iterate supports None batch_size which returns all elements at once
|
||||
# as a single batch.
|
||||
for missing_indices in batch_iterate(self.batch_size, all_missing_indices):
|
||||
missing_texts = [texts[i] for i in missing_indices]
|
||||
missing_vectors = await self.underlying_embeddings.aembed_documents(
|
||||
missing_texts
|
||||
)
|
||||
@ -210,6 +217,7 @@ class CacheBackedEmbeddings(Embeddings):
|
||||
document_embedding_cache: ByteStore,
|
||||
*,
|
||||
namespace: str = "",
|
||||
batch_size: Optional[int] = None,
|
||||
) -> CacheBackedEmbeddings:
|
||||
"""On-ramp that adds the necessary serialization and encoding to the store.
|
||||
|
||||
@ -220,6 +228,7 @@ class CacheBackedEmbeddings(Embeddings):
|
||||
namespace: The namespace to use for document cache.
|
||||
This namespace is used to avoid collisions with other caches.
|
||||
For example, set it to the name of the embedding model used.
|
||||
batch_size: The number of documents to embed between store updates.
|
||||
"""
|
||||
namespace = namespace
|
||||
key_encoder = _create_key_encoder(namespace)
|
||||
@ -229,4 +238,4 @@ class CacheBackedEmbeddings(Embeddings):
|
||||
_value_serializer,
|
||||
_value_deserializer,
|
||||
)
|
||||
return cls(underlying_embeddings, encoder_backed_store)
|
||||
return cls(underlying_embeddings, encoder_backed_store, batch_size=batch_size)
|
||||
|
@ -13,6 +13,8 @@ class MockEmbeddings(Embeddings):
|
||||
# Simulate embedding documents
|
||||
embeddings: List[List[float]] = []
|
||||
for text in texts:
|
||||
if text == "RAISE_EXCEPTION":
|
||||
raise ValueError("Simulated embedding failure")
|
||||
embeddings.append([len(text), len(text) + 1])
|
||||
return embeddings
|
||||
|
||||
@ -31,6 +33,16 @@ def cache_embeddings() -> CacheBackedEmbeddings:
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def cache_embeddings_batch() -> CacheBackedEmbeddings:
|
||||
"""Create a cache backed embeddings with a batch_size of 3."""
|
||||
store = InMemoryStore()
|
||||
embeddings = MockEmbeddings()
|
||||
return CacheBackedEmbeddings.from_bytes_store(
|
||||
embeddings, store, namespace="test_namespace", batch_size=3
|
||||
)
|
||||
|
||||
|
||||
def test_embed_documents(cache_embeddings: CacheBackedEmbeddings) -> None:
|
||||
texts = ["1", "22", "a", "333"]
|
||||
vectors = cache_embeddings.embed_documents(texts)
|
||||
@ -42,6 +54,20 @@ def test_embed_documents(cache_embeddings: CacheBackedEmbeddings) -> None:
|
||||
assert keys[0] == "test_namespace812b86c1-8ebf-5483-95c6-c95cf2b52d12"
|
||||
|
||||
|
||||
def test_embed_documents_batch(cache_embeddings_batch: CacheBackedEmbeddings) -> None:
|
||||
# "RAISE_EXCEPTION" forces a failure in batch 2
|
||||
texts = ["1", "22", "a", "333", "RAISE_EXCEPTION"]
|
||||
try:
|
||||
cache_embeddings_batch.embed_documents(texts)
|
||||
except ValueError:
|
||||
pass
|
||||
keys = list(cache_embeddings_batch.document_embedding_store.yield_keys())
|
||||
# only the first batch of three embeddings should exist
|
||||
assert len(keys) == 3
|
||||
# UUID is expected to be the same for the same text
|
||||
assert keys[0] == "test_namespace812b86c1-8ebf-5483-95c6-c95cf2b52d12"
|
||||
|
||||
|
||||
def test_embed_query(cache_embeddings: CacheBackedEmbeddings) -> None:
|
||||
text = "query_text"
|
||||
vector = cache_embeddings.embed_query(text)
|
||||
@ -62,6 +88,25 @@ async def test_aembed_documents(cache_embeddings: CacheBackedEmbeddings) -> None
|
||||
assert keys[0] == "test_namespace812b86c1-8ebf-5483-95c6-c95cf2b52d12"
|
||||
|
||||
|
||||
async def test_aembed_documents_batch(
|
||||
cache_embeddings_batch: CacheBackedEmbeddings,
|
||||
) -> None:
|
||||
# "RAISE_EXCEPTION" forces a failure in batch 2
|
||||
texts = ["1", "22", "a", "333", "RAISE_EXCEPTION"]
|
||||
try:
|
||||
await cache_embeddings_batch.aembed_documents(texts)
|
||||
except ValueError:
|
||||
pass
|
||||
keys = [
|
||||
key
|
||||
async for key in cache_embeddings_batch.document_embedding_store.ayield_keys()
|
||||
]
|
||||
# only the first batch of three embeddings should exist
|
||||
assert len(keys) == 3
|
||||
# UUID is expected to be the same for the same text
|
||||
assert keys[0] == "test_namespace812b86c1-8ebf-5483-95c6-c95cf2b52d12"
|
||||
|
||||
|
||||
async def test_aembed_query(cache_embeddings: CacheBackedEmbeddings) -> None:
|
||||
text = "query_text"
|
||||
vector = await cache_embeddings.aembed_query(text)
|
||||
|
Loading…
Reference in New Issue
Block a user