Mirror of https://github.com/hwchase17/langchain.git (synced 2025-08-09 04:50:37 +00:00)
core: implement a batch_size parameter for CacheBackedEmbeddings (#18070)
**Description:** Currently, `CacheBackedEmbeddings` computes vectors for *all* uncached documents before updating the store. This pull request updates the embedding computation loop to compute embeddings in batches, updating the store after each batch. I noticed this when I tried `CacheBackedEmbeddings` on our 30k document set and the cache directory hadn't appeared on disk after 30 minutes.

The motivation is to minimize compute/data loss when problems occur:

* If there is a transient embedding failure (e.g. a network outage at the embedding endpoint triggers an exception), at least the completed vectors are written to the store instead of being discarded.
* If there is an issue with the store (e.g. no write permissions), the condition is detected early without computing (and discarding!) all the vectors.

**Issue:** Implements enhancement #18026.

**Testing:** I was unable to run unit tests; details in [this post](https://github.com/langchain-ai/langchain/discussions/15019#discussioncomment-8576684).

---------

Signed-off-by: chrispy <chrispy@synopsys.com>
Co-authored-by: Eugene Yurtsev <eyurtsev@gmail.com>
Parent: 89af30807b · Commit: 305d74c67a
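The failure-resilience behavior described above can be seen end to end in a minimal sketch. The `FlakyEmbeddings` class, the `"BOOM"` marker text, and the choice of `InMemoryStore` are illustrative, not part of the commit:

```python
from typing import List

from langchain.embeddings import CacheBackedEmbeddings
from langchain.storage import InMemoryStore
from langchain_core.embeddings import Embeddings


class FlakyEmbeddings(Embeddings):
    """Toy embedder that raises on a marker text, simulating a network outage."""

    def embed_documents(self, texts: List[str]) -> List[List[float]]:
        out: List[List[float]] = []
        for text in texts:
            if text == "BOOM":
                raise RuntimeError("simulated transient failure")
            out.append([float(len(text))])
        return out

    def embed_query(self, text: str) -> List[float]:
        return self.embed_documents([text])[0]


embedder = CacheBackedEmbeddings.from_bytes_store(
    FlakyEmbeddings(), InMemoryStore(), namespace="demo", batch_size=2
)
try:
    embedder.embed_documents(["a", "bb", "ccc", "BOOM"])
except RuntimeError:
    pass

# With batch_size=2, the first batch ("a", "bb") was persisted before the
# failure in the second batch; without batching, all four vectors would
# have been discarded.
print(len(list(embedder.document_embedding_store.yield_keys())))  # 2
```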
@@ -22,10 +22,11 @@
     "Caching embeddings can be done using a `CacheBackedEmbeddings`. The cache backed embedder is a wrapper around an embedder that caches\n",
     "embeddings in a key-value store. The text is hashed and the hash is used as the key in the cache.\n",
     "\n",
-    "The main supported way to initialized a `CacheBackedEmbeddings` is `from_bytes_store`. This takes in the following parameters:\n",
+    "The main supported way to initialize a `CacheBackedEmbeddings` is `from_bytes_store`. It takes the following parameters:\n",
     "\n",
     "- underlying_embedder: The embedder to use for embedding.\n",
     "- document_embedding_cache: Any [`ByteStore`](/docs/integrations/stores/) for caching document embeddings.\n",
+    "- batch_size: (optional, defaults to `None`) The number of documents to embed between store updates.\n",
     "- namespace: (optional, defaults to `\"\"`) The namespace to use for document cache. This namespace is used to avoid collisions with other caches. For example, set it to the name of the embedding model used.\n",
     "\n",
     "**Attention**: Be sure to set the `namespace` parameter to avoid collisions of the same text embedded using different embeddings models."
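The notebook's parameter list translates to the following setup. This is a sketch: the cache path, namespace, and `batch_size` value are illustrative, and `OpenAIEmbeddings` (from the `langchain-openai` package) stands in for any underlying embedder:

```python
from langchain.embeddings import CacheBackedEmbeddings
from langchain.storage import LocalFileStore
from langchain_openai import OpenAIEmbeddings

underlying_embedder = OpenAIEmbeddings()
store = LocalFileStore("./embedding_cache/")  # illustrative on-disk cache path

cached_embedder = CacheBackedEmbeddings.from_bytes_store(
    underlying_embedder,
    store,
    namespace=underlying_embedder.model,  # avoid cross-model key collisions
    batch_size=100,  # persist to the store after every 100 embedded documents
)

embeddings = cached_embedder.embed_documents(["hello", "world"])
```

With `batch_size=100`, a 30k-document run writes to disk roughly every 100 documents instead of only once at the end, which is exactly the failure mode described in the commit message.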
@@ -165,8 +165,16 @@ class Tee(Generic[T]):
 safetee = Tee
 
 
-def batch_iterate(size: int, iterable: Iterable[T]) -> Iterator[List[T]]:
-    """Utility batching function."""
+def batch_iterate(size: Optional[int], iterable: Iterable[T]) -> Iterator[List[T]]:
+    """Utility batching function.
+
+    Args:
+        size: The size of the batch. If None, returns a single batch.
+        iterable: The iterable to batch.
+
+    Returns:
+        An iterator over the batches.
+    """
     it = iter(iterable)
     while True:
         chunk = list(islice(it, size))
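Standalone, the updated helper behaves as follows. The loop tail (`if not chunk: return` / `yield chunk`) falls just past the hunk shown above and is reconstructed here; the `print` calls are illustrative:

```python
from itertools import islice
from typing import Iterable, Iterator, List, Optional, TypeVar

T = TypeVar("T")


def batch_iterate(size: Optional[int], iterable: Iterable[T]) -> Iterator[List[T]]:
    """Yield lists of up to `size` items; a None size yields everything at once."""
    it = iter(iterable)
    while True:
        chunk = list(islice(it, size))
        if not chunk:
            return
        yield chunk


print(list(batch_iterate(2, range(5))))     # [[0, 1], [2, 3], [4]]
print(list(batch_iterate(None, range(5))))  # [[0, 1, 2, 3, 4]]
```

Because `islice(it, None)` drains the iterator in a single call, the default `batch_size=None` yields one batch containing all missing documents, preserving the previous single-update behavior for existing callers.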
@@ -12,10 +12,11 @@ import hashlib
 import json
 import uuid
 from functools import partial
-from typing import Callable, List, Sequence, Union, cast
+from typing import Callable, List, Optional, Sequence, Union, cast
 
 from langchain_core.embeddings import Embeddings
 from langchain_core.stores import BaseStore, ByteStore
+from langchain_core.utils.iter import batch_iterate
 
 from langchain.storage.encoder_backed import EncoderBackedStore
 
@@ -84,16 +85,20 @@ class CacheBackedEmbeddings(Embeddings):
         self,
         underlying_embeddings: Embeddings,
         document_embedding_store: BaseStore[str, List[float]],
+        *,
+        batch_size: Optional[int] = None,
     ) -> None:
         """Initialize the embedder.
 
         Args:
             underlying_embeddings: the embedder to use for computing embeddings.
             document_embedding_store: The store to use for caching document embeddings.
+            batch_size: The number of documents to embed between store updates.
         """
         super().__init__()
         self.document_embedding_store = document_embedding_store
         self.underlying_embeddings = underlying_embeddings
+        self.batch_size = batch_size
 
     def embed_documents(self, texts: List[str]) -> List[List[float]]:
         """Embed a list of texts.
@@ -111,12 +116,12 @@ class CacheBackedEmbeddings(Embeddings):
         vectors: List[Union[List[float], None]] = self.document_embedding_store.mget(
             texts
         )
-        missing_indices: List[int] = [
+        all_missing_indices: List[int] = [
             i for i, vector in enumerate(vectors) if vector is None
         ]
-        missing_texts = [texts[i] for i in missing_indices]
 
-        if missing_texts:
+        for missing_indices in batch_iterate(self.batch_size, all_missing_indices):
+            missing_texts = [texts[i] for i in missing_indices]
             missing_vectors = self.underlying_embeddings.embed_documents(missing_texts)
             self.document_embedding_store.mset(
                 list(zip(missing_texts, missing_vectors))
@@ -144,12 +149,14 @@ class CacheBackedEmbeddings(Embeddings):
         vectors: List[
             Union[List[float], None]
         ] = await self.document_embedding_store.amget(texts)
-        missing_indices: List[int] = [
+        all_missing_indices: List[int] = [
             i for i, vector in enumerate(vectors) if vector is None
         ]
-        missing_texts = [texts[i] for i in missing_indices]
 
-        if missing_texts:
+        # batch_iterate supports None batch_size which returns all elements at once
+        # as a single batch.
+        for missing_indices in batch_iterate(self.batch_size, all_missing_indices):
+            missing_texts = [texts[i] for i in missing_indices]
             missing_vectors = await self.underlying_embeddings.aembed_documents(
                 missing_texts
             )
@@ -210,6 +217,7 @@ class CacheBackedEmbeddings(Embeddings):
         document_embedding_cache: ByteStore,
         *,
         namespace: str = "",
+        batch_size: Optional[int] = None,
     ) -> CacheBackedEmbeddings:
         """On-ramp that adds the necessary serialization and encoding to the store.
 
@@ -220,6 +228,7 @@ class CacheBackedEmbeddings(Embeddings):
             namespace: The namespace to use for document cache.
                 This namespace is used to avoid collisions with other caches.
                 For example, set it to the name of the embedding model used.
+            batch_size: The number of documents to embed between store updates.
         """
         namespace = namespace
         key_encoder = _create_key_encoder(namespace)
@@ -229,4 +238,4 @@ class CacheBackedEmbeddings(Embeddings):
             _value_serializer,
             _value_deserializer,
         )
-        return cls(underlying_embeddings, encoder_backed_store)
+        return cls(underlying_embeddings, encoder_backed_store, batch_size=batch_size)
@@ -13,6 +13,8 @@ class MockEmbeddings(Embeddings):
         # Simulate embedding documents
         embeddings: List[List[float]] = []
         for text in texts:
+            if text == "RAISE_EXCEPTION":
+                raise ValueError("Simulated embedding failure")
             embeddings.append([len(text), len(text) + 1])
         return embeddings
 
@@ -31,6 +33,16 @@ def cache_embeddings() -> CacheBackedEmbeddings:
     )
 
 
+@pytest.fixture
+def cache_embeddings_batch() -> CacheBackedEmbeddings:
+    """Create a cache backed embeddings with a batch_size of 3."""
+    store = InMemoryStore()
+    embeddings = MockEmbeddings()
+    return CacheBackedEmbeddings.from_bytes_store(
+        embeddings, store, namespace="test_namespace", batch_size=3
+    )
+
+
 def test_embed_documents(cache_embeddings: CacheBackedEmbeddings) -> None:
     texts = ["1", "22", "a", "333"]
     vectors = cache_embeddings.embed_documents(texts)
@@ -42,6 +54,20 @@ def test_embed_documents(cache_embeddings: CacheBackedEmbeddings) -> None:
     assert keys[0] == "test_namespace812b86c1-8ebf-5483-95c6-c95cf2b52d12"
 
 
+def test_embed_documents_batch(cache_embeddings_batch: CacheBackedEmbeddings) -> None:
+    # "RAISE_EXCEPTION" forces a failure in batch 2
+    texts = ["1", "22", "a", "333", "RAISE_EXCEPTION"]
+    try:
+        cache_embeddings_batch.embed_documents(texts)
+    except ValueError:
+        pass
+    keys = list(cache_embeddings_batch.document_embedding_store.yield_keys())
+    # only the first batch of three embeddings should exist
+    assert len(keys) == 3
+    # UUID is expected to be the same for the same text
+    assert keys[0] == "test_namespace812b86c1-8ebf-5483-95c6-c95cf2b52d12"
+
+
 def test_embed_query(cache_embeddings: CacheBackedEmbeddings) -> None:
     text = "query_text"
     vector = cache_embeddings.embed_query(text)
@@ -62,6 +88,25 @@ async def test_aembed_documents(cache_embeddings: CacheBackedEmbeddings) -> None
     assert keys[0] == "test_namespace812b86c1-8ebf-5483-95c6-c95cf2b52d12"
 
 
+async def test_aembed_documents_batch(
+    cache_embeddings_batch: CacheBackedEmbeddings,
+) -> None:
+    # "RAISE_EXCEPTION" forces a failure in batch 2
+    texts = ["1", "22", "a", "333", "RAISE_EXCEPTION"]
+    try:
+        await cache_embeddings_batch.aembed_documents(texts)
+    except ValueError:
+        pass
+    keys = [
+        key
+        async for key in cache_embeddings_batch.document_embedding_store.ayield_keys()
+    ]
+    # only the first batch of three embeddings should exist
+    assert len(keys) == 3
+    # UUID is expected to be the same for the same text
+    assert keys[0] == "test_namespace812b86c1-8ebf-5483-95c6-c95cf2b52d12"
+
+
 async def test_aembed_query(cache_embeddings: CacheBackedEmbeddings) -> None:
     text = "query_text"
     vector = await cache_embeddings.aembed_query(text)