core: implement a batch_size parameter for CacheBackedEmbeddings (#18070)
**Description:** Currently, `CacheBackedEmbeddings` computes vectors for *all* uncached documents before updating the store. This pull request updates the embedding computation loop to compute embeddings in batches, updating the store after each batch. I noticed this when I tried `CacheBackedEmbeddings` on our 30k document set and the cache directory hadn't appeared on disk after 30 minutes. The motivation is to minimize compute/data loss when problems occur:

* If there is a transient embedding failure (e.g. a network outage at the embedding endpoint triggers an exception), at least the completed vectors are written to the store instead of being discarded.
* If there is an issue with the store (e.g. no write permissions), the condition is detected early without computing (and discarding!) all the vectors.

**Issue:** Implements enhancement #18026.

**Testing:** I was unable to run unit tests; details in [this post](https://github.com/langchain-ai/langchain/discussions/15019#discussioncomment-8576684).

Signed-off-by: chrispy <chrispy@synopsys.com>
Co-authored-by: Eugene Yurtsev <eyurtsev@gmail.com>
Committed by: GitHub
Parent: 89af30807b
Commit: 305d74c67a
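For context, a minimal usage sketch of the new parameter (assuming `OpenAIEmbeddings` and `LocalFileStore` as the concrete embedder and byte store; any `Embeddings`/`ByteStore` pair works the same way):

```python
from langchain.embeddings import CacheBackedEmbeddings
from langchain.storage import LocalFileStore
from langchain_openai import OpenAIEmbeddings

underlying = OpenAIEmbeddings()
store = LocalFileStore("./embedding_cache/")
cached = CacheBackedEmbeddings.from_bytes_store(
    underlying,
    store,
    namespace=underlying.model,  # e.g. the model name, to avoid cache collisions
    batch_size=100,  # flush vectors to the store after every 100 documents
)
# Vectors now reach the store after each batch of 100, so a failure at
# document 25,000 of 30,000 no longer discards the completed work.
vectors = cached.embed_documents(["first document", "second document"])
```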
First file (the `CacheBackedEmbeddings` implementation):

```diff
@@ -12,10 +12,11 @@ import hashlib
 import json
 import uuid
 from functools import partial
-from typing import Callable, List, Sequence, Union, cast
+from typing import Callable, List, Optional, Sequence, Union, cast
 
 from langchain_core.embeddings import Embeddings
 from langchain_core.stores import BaseStore, ByteStore
+from langchain_core.utils.iter import batch_iterate
 
 from langchain.storage.encoder_backed import EncoderBackedStore
```
```diff
@@ -84,16 +85,20 @@ class CacheBackedEmbeddings(Embeddings):
         self,
         underlying_embeddings: Embeddings,
         document_embedding_store: BaseStore[str, List[float]],
+        *,
+        batch_size: Optional[int] = None,
     ) -> None:
         """Initialize the embedder.
 
         Args:
             underlying_embeddings: the embedder to use for computing embeddings.
             document_embedding_store: The store to use for caching document embeddings.
+            batch_size: The number of documents to embed between store updates.
         """
         super().__init__()
         self.document_embedding_store = document_embedding_store
         self.underlying_embeddings = underlying_embeddings
+        self.batch_size = batch_size
 
     def embed_documents(self, texts: List[str]) -> List[List[float]]:
         """Embed a list of texts.
```
```diff
@@ -111,12 +116,12 @@ class CacheBackedEmbeddings(Embeddings):
         vectors: List[Union[List[float], None]] = self.document_embedding_store.mget(
             texts
         )
-        missing_indices: List[int] = [
+        all_missing_indices: List[int] = [
             i for i, vector in enumerate(vectors) if vector is None
         ]
-        missing_texts = [texts[i] for i in missing_indices]
 
-        if missing_texts:
+        for missing_indices in batch_iterate(self.batch_size, all_missing_indices):
+            missing_texts = [texts[i] for i in missing_indices]
             missing_vectors = self.underlying_embeddings.embed_documents(missing_texts)
             self.document_embedding_store.mset(
                 list(zip(missing_texts, missing_vectors))
```
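A small end-to-end illustration of the new failure semantics in the loop above (a sketch, not part of the diff: `FlakyEmbeddings` is a hypothetical stand-in embedder, and `InMemoryByteStore` comes from `langchain_core.stores`):

```python
from typing import List

from langchain.embeddings import CacheBackedEmbeddings
from langchain_core.embeddings import Embeddings
from langchain_core.stores import InMemoryByteStore


class FlakyEmbeddings(Embeddings):
    """Hypothetical embedder that fails on a sentinel text."""

    def embed_documents(self, texts: List[str]) -> List[List[float]]:
        if "BOOM" in texts:
            raise RuntimeError("transient failure at the embedding endpoint")
        return [[float(len(t))] for t in texts]

    def embed_query(self, text: str) -> List[float]:
        return self.embed_documents([text])[0]


store = InMemoryByteStore()
cached = CacheBackedEmbeddings.from_bytes_store(
    FlakyEmbeddings(), store, namespace="demo", batch_size=2
)
try:
    cached.embed_documents(["a", "bb", "BOOM"])
except RuntimeError:
    pass
# With batch_size=2, the first batch ["a", "bb"] was cached before the failure.
assert len(list(store.yield_keys())) == 2
```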
```diff
@@ -144,12 +149,14 @@ class CacheBackedEmbeddings(Embeddings):
         vectors: List[
             Union[List[float], None]
         ] = await self.document_embedding_store.amget(texts)
-        missing_indices: List[int] = [
+        all_missing_indices: List[int] = [
             i for i, vector in enumerate(vectors) if vector is None
         ]
-        missing_texts = [texts[i] for i in missing_indices]
 
-        if missing_texts:
+        # batch_iterate supports None batch_size which returns all elements at once
+        # as a single batch.
+        for missing_indices in batch_iterate(self.batch_size, all_missing_indices):
+            missing_texts = [texts[i] for i in missing_indices]
             missing_vectors = await self.underlying_embeddings.aembed_documents(
                 missing_texts
             )
```
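As the inline comment notes, `batch_size=None` (the default) preserves the previous all-at-once behavior; a quick sketch of `batch_iterate` from `langchain_core.utils.iter`:

```python
from langchain_core.utils.iter import batch_iterate

# size=None consumes the whole iterable as one batch -- the pre-change behavior.
print(list(batch_iterate(None, [0, 1, 2, 3, 4])))  # [[0, 1, 2, 3, 4]]
print(list(batch_iterate(2, [0, 1, 2, 3, 4])))     # [[0, 1], [2, 3], [4]]
```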
```diff
@@ -210,6 +217,7 @@ class CacheBackedEmbeddings(Embeddings):
         document_embedding_cache: ByteStore,
         *,
         namespace: str = "",
+        batch_size: Optional[int] = None,
     ) -> CacheBackedEmbeddings:
         """On-ramp that adds the necessary serialization and encoding to the store.
 
```
```diff
@@ -220,6 +228,7 @@ class CacheBackedEmbeddings(Embeddings):
             namespace: The namespace to use for document cache.
                        This namespace is used to avoid collisions with other caches.
                        For example, set it to the name of the embedding model used.
+            batch_size: The number of documents to embed between store updates.
         """
         namespace = namespace
         key_encoder = _create_key_encoder(namespace)
```
```diff
@@ -229,4 +238,4 @@ class CacheBackedEmbeddings(Embeddings):
             _value_serializer,
             _value_deserializer,
         )
-        return cls(underlying_embeddings, encoder_backed_store)
+        return cls(underlying_embeddings, encoder_backed_store, batch_size=batch_size)
```
Second file (the unit tests):

```diff
@@ -13,6 +13,8 @@ class MockEmbeddings(Embeddings):
         # Simulate embedding documents
         embeddings: List[List[float]] = []
         for text in texts:
+            if text == "RAISE_EXCEPTION":
+                raise ValueError("Simulated embedding failure")
             embeddings.append([len(text), len(text) + 1])
         return embeddings
 
```
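For reference, the mock's rule restated as a standalone snippet (this mirrors the class above; `mock_embed` is just an illustrative name):

```python
from typing import List


def mock_embed(texts: List[str]) -> List[List[float]]:
    """Deterministic stand-in: each text maps to [len(text), len(text) + 1]."""
    embeddings: List[List[float]] = []
    for text in texts:
        if text == "RAISE_EXCEPTION":
            raise ValueError("Simulated embedding failure")
        embeddings.append([len(text), len(text) + 1])
    return embeddings


assert mock_embed(["1", "22", "333"]) == [[1, 2], [2, 3], [3, 4]]
```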
```diff
@@ -31,6 +33,16 @@ def cache_embeddings() -> CacheBackedEmbeddings:
     )
 
 
+@pytest.fixture
+def cache_embeddings_batch() -> CacheBackedEmbeddings:
+    """Create a cache backed embeddings with a batch_size of 3."""
+    store = InMemoryStore()
+    embeddings = MockEmbeddings()
+    return CacheBackedEmbeddings.from_bytes_store(
+        embeddings, store, namespace="test_namespace", batch_size=3
+    )
+
+
 def test_embed_documents(cache_embeddings: CacheBackedEmbeddings) -> None:
     texts = ["1", "22", "a", "333"]
     vectors = cache_embeddings.embed_documents(texts)
```
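With `batch_size=3`, the five texts used in the batch tests below partition so that only the first batch completes before the simulated failure; a quick check with `batch_iterate` (batching the texts directly here, where the implementation batches their indices, is equivalent):

```python
from langchain_core.utils.iter import batch_iterate

texts = ["1", "22", "a", "333", "RAISE_EXCEPTION"]
print(list(batch_iterate(3, texts)))
# [['1', '22', 'a'], ['333', 'RAISE_EXCEPTION']] -> the first three are persisted,
# then the second batch raises, which is why the tests expect len(keys) == 3.
```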
```diff
@@ -42,6 +54,20 @@ def test_embed_documents(cache_embeddings: CacheBackedEmbeddings) -> None:
     assert keys[0] == "test_namespace812b86c1-8ebf-5483-95c6-c95cf2b52d12"
 
 
+def test_embed_documents_batch(cache_embeddings_batch: CacheBackedEmbeddings) -> None:
+    # "RAISE_EXCEPTION" forces a failure in batch 2
+    texts = ["1", "22", "a", "333", "RAISE_EXCEPTION"]
+    try:
+        cache_embeddings_batch.embed_documents(texts)
+    except ValueError:
+        pass
+    keys = list(cache_embeddings_batch.document_embedding_store.yield_keys())
+    # only the first batch of three embeddings should exist
+    assert len(keys) == 3
+    # UUID is expected to be the same for the same text
+    assert keys[0] == "test_namespace812b86c1-8ebf-5483-95c6-c95cf2b52d12"
+
+
 def test_embed_query(cache_embeddings: CacheBackedEmbeddings) -> None:
     text = "query_text"
     vector = cache_embeddings.embed_query(text)
```
```diff
@@ -62,6 +88,25 @@ async def test_aembed_documents(cache_embeddings: CacheBackedEmbeddings) -> None
     assert keys[0] == "test_namespace812b86c1-8ebf-5483-95c6-c95cf2b52d12"
 
 
+async def test_aembed_documents_batch(
+    cache_embeddings_batch: CacheBackedEmbeddings,
+) -> None:
+    # "RAISE_EXCEPTION" forces a failure in batch 2
+    texts = ["1", "22", "a", "333", "RAISE_EXCEPTION"]
+    try:
+        await cache_embeddings_batch.aembed_documents(texts)
+    except ValueError:
+        pass
+    keys = [
+        key
+        async for key in cache_embeddings_batch.document_embedding_store.ayield_keys()
+    ]
+    # only the first batch of three embeddings should exist
+    assert len(keys) == 3
+    # UUID is expected to be the same for the same text
+    assert keys[0] == "test_namespace812b86c1-8ebf-5483-95c6-c95cf2b52d12"
+
+
 async def test_aembed_query(cache_embeddings: CacheBackedEmbeddings) -> None:
     text = "query_text"
     vector = await cache_embeddings.aembed_query(text)
```