Add embeddings cache (#8976)

This PR adds the ability to temporarily cache or persistently store
embeddings. 

A notebook has been included showing how to set up the cache and how to
use it with a vectorstore.
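
At a glance, usage looks like this (a sketch based on the `CacheBackedEmbeddings.from_bytes_store` API added in this PR; the store path is illustrative):

```python
from langchain.embeddings import CacheBackedEmbeddings, OpenAIEmbeddings
from langchain.storage import LocalFileStore

store = LocalFileStore("./embeddings_cache/")  # any BaseStore[str, bytes] works here
underlying = OpenAIEmbeddings()
cached_embedder = CacheBackedEmbeddings.from_bytes_store(
    underlying, store, namespace=underlying.model  # namespace avoids cross-model collisions
)

vectors = cached_embedder.embed_documents(["hello", "goodbye"])         # computed and cached
cached_vectors = cached_embedder.embed_documents(["hello", "goodbye"])  # served from the cache
```
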
Eugene Yurtsev
2023-08-10 11:15:30 -04:00
committed by GitHub
parent 6e14f9548b
commit 5e05ba2140
4 changed files with 695 additions and 2 deletions


@@ -0,0 +1,467 @@
{
"cells": [
{
"cell_type": "markdown",
"id": "bf4061ce",
"metadata": {},
"source": [
"# Caching Embeddings\n",
"\n",
"Embeddings can be stored or temporarily cached to avoid needing to recompute them.\n",
"\n",
"Caching embeddings can be done using a `CacheBackedEmbedder`.\n",
"\n",
"The cache backed embedder is a wrapper around an embedder that caches\n",
"embeddings in a key-value store. \n",
"\n",
"The text is hashed and the hash is used as the key in the cache.\n",
"\n",
"\n",
"The main supported way to initialized a `CacheBackedEmbedder` is `from_bytes_store`. This takes in the following parameters:\n",
"\n",
"- underlying_embedder: The embedder to use for embedding.\n",
"- document_embedding_cache: The cache to use for storing document embeddings.\n",
"- namespace: (optional, defaults to `\"\"`) The namespace to use for document cache. This namespace is used to avoid collisions with other caches. For example, set it to the name of the embedding model used.\n",
"\n",
"**Attention**: Be sure to set the `namespace` parameter to avoid collisions of the same text embedded using different embeddings models."
]
},
{
"cell_type": "code",
"execution_count": 1,
"id": "a463c3c2-749b-40d1-a433-84f68a1cd1c7",
"metadata": {
"tags": []
},
"outputs": [],
"source": [
"from langchain.embeddings import CacheBackedEmbedder\n",
"from langchain.storage import InMemoryStore\n",
"from langchain.storage import LocalFileStore\n",
"from langchain.embeddings import OpenAIEmbeddings"
]
},
{
"cell_type": "markdown",
"id": "564c9801-29f0-4452-aeac-527382e2c0e8",
"metadata": {},
"source": [
"## In Memory\n",
"\n",
"This section shows how to set up an in memory cache for embeddings. This type of cache is primarily \n",
"useful for unit tests or prototyping. Do **not** use this cache if you need to actually store the embeddings."
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "13bd1c5b-b7ba-4394-957c-7d5b5a841972",
"metadata": {
"tags": []
},
"outputs": [],
"source": [
"store = InMemoryStore()"
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "9d99885f-99e1-498c-904d-6db539ac9466",
"metadata": {
"tags": []
},
"outputs": [],
"source": [
"underlying_embedder = OpenAIEmbeddings()\n",
"embedder = CacheBackedEmbedder.from_bytes_store(\n",
" underlying_embedder, store, namespace=underlying_embedder.model\n",
")"
]
},
{
"cell_type": "code",
"execution_count": 4,
"id": "682eb5d4-0b7a-4dac-b8fb-3de4ca6e421c",
"metadata": {
"tags": []
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"CPU times: user 405 ms, sys: 32.9 ms, total: 438 ms\n",
"Wall time: 715 ms\n"
]
}
],
"source": [
"%%time\n",
"embeddings = embedder.embed_documents([\"hello\", \"goodbye\"])"
]
},
{
"cell_type": "markdown",
"id": "95233026-147f-49d1-bd87-e1e8b88ebdbc",
"metadata": {},
"source": [
"The second time we try to embed the embedding time is only 2 ms because the embeddings are looked up in the cache."
]
},
{
"cell_type": "code",
"execution_count": 5,
"id": "f819c3ff-a212-4d06-a5f7-5eb1435c1feb",
"metadata": {
"tags": []
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"CPU times: user 1.55 ms, sys: 436 µs, total: 1.99 ms\n",
"Wall time: 1.99 ms\n"
]
}
],
"source": [
"%%time\n",
"embeddings_from_cache = embedder.embed_documents([\"hello\", \"goodbye\"])"
]
},
{
"cell_type": "code",
"execution_count": 6,
"id": "ec38fb72-90a9-4687-a483-c62c87d1f4dd",
"metadata": {
"tags": []
},
"outputs": [
{
"data": {
"text/plain": [
"True"
]
},
"execution_count": 6,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"embeddings == embeddings_from_cache"
]
},
{
"cell_type": "markdown",
"id": "f6cbe100-8587-4830-b207-fb8b524a9854",
"metadata": {},
"source": [
"## File system\n",
"\n",
"This section covers how to use a file system store"
]
},
{
"cell_type": "code",
"execution_count": 7,
"id": "a0070271-0809-4528-97e0-2a88216846f3",
"metadata": {
"tags": []
},
"outputs": [],
"source": [
"fs = LocalFileStore(\"./cache/\")"
]
},
{
"cell_type": "code",
"execution_count": 8,
"id": "0b20e9fe-f57f-4d7c-9f81-105c5f8726f4",
"metadata": {
"tags": []
},
"outputs": [],
"source": [
"embedder2 = CacheBackedEmbedder.from_bytes_store(\n",
" underlying_embedder, fs, namespace=underlying_embedder.model\n",
")"
]
},
{
"cell_type": "code",
"execution_count": 9,
"id": "630515fd-bf5c-4d9c-a404-9705308f3a2c",
"metadata": {
"tags": []
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"CPU times: user 10.5 ms, sys: 988 µs, total: 11.5 ms\n",
"Wall time: 220 ms\n"
]
}
],
"source": [
"%%time\n",
"embeddings = embedder2.embed_documents([\"hello\", \"goodbye\"])"
]
},
{
"cell_type": "code",
"execution_count": 10,
"id": "30e6bb87-42c9-4d08-88ac-0d22c9c449a1",
"metadata": {
"tags": []
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"CPU times: user 3.49 ms, sys: 0 ns, total: 3.49 ms\n",
"Wall time: 3.03 ms\n"
]
}
],
"source": [
"%%time\n",
"embeddings = embedder2.embed_documents([\"hello\", \"goodbye\"])"
]
},
{
"cell_type": "markdown",
"id": "12ed5a45-8352-4e0f-8583-5537397f53c0",
"metadata": {},
"source": [
"Here are the embeddings that have been persisted to the directory `./cache`. \n",
"\n",
"Notice that the embedder takes a namespace parameter."
]
},
{
"cell_type": "code",
"execution_count": 11,
"id": "658e2914-05e9-44a3-a8fe-3fe17ca84039",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"['text-embedding-ada-002e885db5b-c0bd-5fbc-88b1-4d1da6020aa5',\n",
" 'text-embedding-ada-0026ba52e44-59c9-5cc9-a084-284061b13c80']"
]
},
"execution_count": 11,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"list(fs.yield_keys())"
]
},
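{
"cell_type": "markdown",
"id": "4f7a9c21-key-derivation-note",
"metadata": {},
"source": [
"For reference, here is a sketch of how those keys are derived (it mirrors the private helpers in the new `langchain/embeddings/cache.py` module, so it is illustrative rather than a supported API): the text is SHA-1 hashed, mapped to a deterministic UUID, and prefixed with the namespace."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "4f7a9c21-key-derivation-code",
"metadata": {},
"outputs": [],
"source": [
"import hashlib\n",
"import uuid\n",
"\n",
"# Mirrors the private key encoder in langchain/embeddings/cache.py (illustrative only).\n",
"NAMESPACE_UUID = uuid.UUID(int=1985)\n",
"\n",
"\n",
"def expected_key(text: str, namespace: str) -> str:\n",
"    digest = hashlib.sha1(text.encode(\"utf-8\")).hexdigest()\n",
"    return namespace + str(uuid.uuid5(NAMESPACE_UUID, digest))\n",
"\n",
"\n",
"expected_key(\"hello\", underlying_embedder.model)"
]
},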
{
"cell_type": "markdown",
"id": "c67f8e97-4851-4e26-ab6f-3418b0188dc4",
"metadata": {},
"source": [
"## Using with a vectorstore\n",
"\n",
"Let's see this cache in action with the FAISS vectorstore."
]
},
{
"cell_type": "code",
"execution_count": 12,
"id": "9e4314d8-88ef-4f52-81ae-0be771168bb6",
"metadata": {},
"outputs": [],
"source": [
"from langchain.document_loaders import TextLoader\n",
"from langchain.embeddings.openai import OpenAIEmbeddings\n",
"from langchain.text_splitter import CharacterTextSplitter\n",
"from langchain.vectorstores import FAISS"
]
},
{
"cell_type": "code",
"execution_count": 13,
"id": "30743664-38f5-425d-8216-772b64e7f348",
"metadata": {},
"outputs": [],
"source": [
"fs = LocalFileStore(\"./cache/\")\n",
"\n",
"cached_embedder = CacheBackedEmbedder.from_bytes_store(\n",
" OpenAIEmbeddings(), fs, namespace=underlying_embedder.model\n",
")"
]
},
{
"cell_type": "markdown",
"id": "06a6f305-724f-4b71-adef-be0169f61381",
"metadata": {},
"source": [
"The cache is empty prior to embedding"
]
},
{
"cell_type": "code",
"execution_count": 14,
"id": "f9ad627f-ced2-4277-b336-2434f22f2c8a",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"['text-embedding-ada-002e885db5b-c0bd-5fbc-88b1-4d1da6020aa5',\n",
" 'text-embedding-ada-0026ba52e44-59c9-5cc9-a084-284061b13c80']"
]
},
"execution_count": 14,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"list(fs.yield_keys())"
]
},
{
"cell_type": "markdown",
"id": "5814aa9c-e8e4-4079-accf-53c49615971e",
"metadata": {},
"source": [
"Load the document, split it into chunks, embed each chunk and load it into the vector store."
]
},
{
"cell_type": "code",
"execution_count": 16,
"id": "cf958ac2-e60e-4668-b32c-8bb2d78b3c61",
"metadata": {},
"outputs": [],
"source": [
"raw_documents = TextLoader(\"../state_of_the_union.txt\").load()\n",
"text_splitter = CharacterTextSplitter(chunk_size=1000, chunk_overlap=0)\n",
"documents = text_splitter.split_documents(raw_documents)"
]
},
{
"cell_type": "markdown",
"id": "fc433fec-ab64-4f11-ae8b-fc3dd76cd79a",
"metadata": {},
"source": [
"create the vectorstore"
]
},
{
"cell_type": "code",
"execution_count": 17,
"id": "3a1d7bb8-3b72-4bb5-9013-cf7729caca61",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"CPU times: user 124 ms, sys: 22.6 ms, total: 146 ms\n",
"Wall time: 832 ms\n"
]
}
],
"source": [
"%%time\n",
"db = FAISS.from_documents(documents, cached_embedder)"
]
},
{
"cell_type": "markdown",
"id": "c94a734c-fa66-40ce-8610-12b00b7df334",
"metadata": {},
"source": [
"If we try to create the vectostore again, it'll be much faster since it does not need to re-compute any embeddings."
]
},
{
"cell_type": "code",
"execution_count": 18,
"id": "714cb2e2-77ba-41a8-bb83-84e75342af2d",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"CPU times: user 32.9 ms, sys: 286 µs, total: 33.2 ms\n",
"Wall time: 32.5 ms\n"
]
}
],
"source": [
"%%time\n",
"db2 = FAISS.from_documents(documents, cached_embedder)"
]
},
{
"cell_type": "markdown",
"id": "93d37b2a-5406-4e2c-b786-869e2430d19d",
"metadata": {},
"source": [
"And here are some of the embeddings that got created:"
]
},
{
"cell_type": "code",
"execution_count": 19,
"id": "f2ca32dd-3712-4093-942b-4122f3dc8a8e",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"['text-embedding-ada-002614d7cf6-46f1-52fa-9d3a-740c39e7a20e',\n",
" 'text-embedding-ada-0020fc1ede2-407a-5e14-8f8f-5642214263f5',\n",
" 'text-embedding-ada-002e885db5b-c0bd-5fbc-88b1-4d1da6020aa5',\n",
" 'text-embedding-ada-002e4ad20ef-dfaa-5916-9459-f90c6d8e8159',\n",
" 'text-embedding-ada-002a5ef11e4-0474-5725-8d80-81c91943b37f']"
]
},
"execution_count": 19,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"list(fs.yield_keys())[:5]"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.10"
}
},
"nbformat": 4,
"nbformat_minor": 5
}


@@ -20,6 +20,7 @@ from langchain.embeddings.aleph_alpha import (
)
from langchain.embeddings.awa import AwaEmbeddings
from langchain.embeddings.bedrock import BedrockEmbeddings
from langchain.embeddings.cache import CacheBackedEmbeddings
from langchain.embeddings.clarifai import ClarifaiEmbeddings
from langchain.embeddings.cohere import CohereEmbeddings
from langchain.embeddings.dashscope import DashScopeEmbeddings
@@ -62,10 +63,11 @@ logger = logging.getLogger(__name__)
__all__ = [
"OpenAIEmbeddings",
"HuggingFaceEmbeddings",
"CohereEmbeddings",
"CacheBackedEmbeddings",
"ClarifaiEmbeddings",
"CohereEmbeddings",
"ElasticsearchEmbeddings",
"HuggingFaceEmbeddings",
"JinaEmbeddings",
"LlamaCppEmbeddings",
"HuggingFaceHubEmbeddings",


@@ -0,0 +1,175 @@
"""Module contains code for a cache backed embedder.
The cache backed embedder is a wrapper around an embedder that caches
embeddings in a key-value store. The cache is used to avoid recomputing
embeddings for the same text.
The text is hashed and the hash is used as the key in the cache.
"""
from __future__ import annotations
import hashlib
import json
import uuid
from functools import partial
from typing import Callable, List, Sequence, Union, cast
from langchain.embeddings.base import Embeddings
from langchain.schema import BaseStore
from langchain.storage.encoder_backed import EncoderBackedStore
NAMESPACE_UUID = uuid.UUID(int=1985)
def _hash_string_to_uuid(input_string: str) -> uuid.UUID:
"""Hash a string and returns the corresponding UUID."""
hash_value = hashlib.sha1(input_string.encode("utf-8")).hexdigest()
return uuid.uuid5(NAMESPACE_UUID, hash_value)
def _key_encoder(key: str, namespace: str) -> str:
"""Encode a key."""
return namespace + str(_hash_string_to_uuid(key))
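# Illustrative example: _key_encoder("hello", "text-embedding-ada-002") returns the
# namespace followed by a deterministic UUID derived from the text's SHA-1 hash,
# e.g. "text-embedding-ada-002<uuid5-of-sha1('hello')>".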
def _create_key_encoder(namespace: str) -> Callable[[str], str]:
"""Create an encoder for a key."""
return partial(_key_encoder, namespace=namespace)
def _value_serializer(value: Sequence[float]) -> bytes:
"""Serialize a value."""
return json.dumps(value).encode()
def _value_deserializer(serialized_value: bytes) -> List[float]:
"""Deserialize a value."""
return cast(List[float], json.loads(serialized_value.decode()))
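# Values round-trip through JSON, e.g. _value_serializer([0.1, 0.2]) == b"[0.1, 0.2]"
# and _value_deserializer(b"[0.1, 0.2]") == [0.1, 0.2].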
class CacheBackedEmbeddings(Embeddings):
"""Interface for caching results from embedding models.
The interface works with any store that implements
the abstract store interface, accepting keys of type str and values that are
lists of floats.
If need be, the interface can be extended to accept other implementations
of the value serializer and deserializer, as well as the key encoder.
Examples:
.. code-block:: python
from langchain.embeddings import CacheBackedEmbeddings, OpenAIEmbeddings
from langchain.storage import LocalFileStore
store = LocalFileStore('./my_cache')
underlying_embedder = OpenAIEmbeddings()
embedder = CacheBackedEmbeddings.from_bytes_store(
underlying_embedder, store, namespace=underlying_embedder.model
)
# Embedding is computed and cached
embeddings = embedder.embed_documents(["hello", "goodbye"])
# Embeddings are retrieved from the cache, no computation is done
embeddings = embedder.embed_documents(["hello", "goodbye"])
"""
def __init__(
self,
underlying_embeddings: Embeddings,
document_embedding_store: BaseStore[str, List[float]],
) -> None:
"""Initialize the embedder.
Args:
underlying_embeddings: the embedder to use for computing embeddings.
document_embedding_store: The store to use for caching document embeddings.
"""
super().__init__()
self.document_embedding_store = document_embedding_store
self.underlying_embeddings = underlying_embeddings
def embed_documents(self, texts: List[str]) -> List[List[float]]:
"""Embed a list of texts.
The method first checks the cache for the embeddings.
If the embeddings are not found, the method uses the underlying embedder
to embed the documents and stores the results in the cache.
Args:
texts: A list of texts to embed.
Returns:
A list of embeddings for the given texts.
"""
vectors: List[Union[List[float], None]] = self.document_embedding_store.mget(
texts
)
missing_indices: List[int] = [
i for i, vector in enumerate(vectors) if vector is None
]
missing_texts = [texts[i] for i in missing_indices]
if missing_texts:
missing_vectors = self.underlying_embeddings.embed_documents(missing_texts)
self.document_embedding_store.mset(
list(zip(missing_texts, missing_vectors))
)
for index, updated_vector in zip(missing_indices, missing_vectors):
vectors[index] = updated_vector
return cast(
List[List[float]], vectors
) # Nones should have been resolved by now
def embed_query(self, text: str) -> List[float]:
"""Embed query text.
This method does not support caching at the moment.
Support for caching queries would be easy to implement, but it may make
sense to hold off until the most common usage patterns are clear.
If the cache has an eviction policy, we may need to be a bit more careful
about sharing the cache between documents and queries. Generally,
it is fine to evict query caches, but document caches should be kept.
Args:
text: The text to embed.
Returns:
The embedding for the given text.
"""
return self.underlying_embeddings.embed_query(text)
@classmethod
def from_bytes_store(
cls,
underlying_embeddings: Embeddings,
document_embedding_cache: BaseStore[str, bytes],
*,
namespace: str = "",
) -> CacheBackedEmbeddings:
"""On-ramp that adds the necessary serialization and encoding to the store.
Args:
underlying_embeddings: The embedder to use for embedding.
document_embedding_cache: The cache to use for storing document embeddings.
namespace: The namespace to use for the document cache.
This namespace is used to avoid collisions with other caches.
For example, set it to the name of the embedding model used.
"""
key_encoder = _create_key_encoder(namespace)
encoder_backed_store = EncoderBackedStore[str, List[float]](
document_embedding_cache,
key_encoder,
_value_serializer,
_value_deserializer,
)
return cls(underlying_embeddings, encoder_backed_store)
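
Queries are not cached by this change (see the `embed_query` docstring above). Below is a minimal sketch of how query caching could be layered on top of the public interface while leaving the document cache alone; the wrapper class and its name are illustrative and not part of this PR.

from typing import Dict, List

from langchain.embeddings.base import Embeddings


class QueryCachingEmbeddings(Embeddings):
    """Wrap any Embeddings implementation and memoize embed_query results in memory."""

    def __init__(self, inner: Embeddings) -> None:
        self._inner = inner
        self._query_cache: Dict[str, List[float]] = {}

    def embed_documents(self, texts: List[str]) -> List[List[float]]:
        # Document embedding is delegated unchanged (e.g. to a CacheBackedEmbeddings).
        return self._inner.embed_documents(texts)

    def embed_query(self, text: str) -> List[float]:
        # Memoize per process; swap the dict for an LRU cache if eviction is needed.
        if text not in self._query_cache:
            self._query_cache[text] = self._inner.embed_query(text)
        return self._query_cache[text]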


@@ -0,0 +1,49 @@
"""Embeddings tests."""
from typing import List
import pytest
from langchain.embeddings import CacheBackedEmbeddings
from langchain.embeddings.base import Embeddings
from langchain.storage.in_memory import InMemoryStore
class MockEmbeddings(Embeddings):
def embed_documents(self, texts: List[str]) -> List[List[float]]:
# Simulate embedding documents
embeddings: List[List[float]] = []
for text in texts:
embeddings.append([len(text), len(text) + 1])
return embeddings
def embed_query(self, text: str) -> List[float]:
# Simulate embedding a query
return [5.0, 6.0]
@pytest.fixture
def cache_embeddings() -> CacheBackedEmbeddings:
"""Create a cache backed embeddings."""
store = InMemoryStore()
embeddings = MockEmbeddings()
return CacheBackedEmbeddings.from_bytes_store(
embeddings, store, namespace="test_namespace"
)
def test_embed_documents(cache_embeddings: CacheBackedEmbeddings) -> None:
texts = ["1", "22", "a", "333"]
vectors = cache_embeddings.embed_documents(texts)
expected_vectors: List[List[float]] = [[1.0, 2.0], [2.0, 3.0], [1.0, 2.0], [3.0, 4.0]]
assert vectors == expected_vectors
keys = list(cache_embeddings.document_embedding_store.yield_keys())
assert len(keys) == 4
# UUID is expected to be the same for the same text
assert keys[0] == "test_namespace812b86c1-8ebf-5483-95c6-c95cf2b52d12"
def test_embed_query(cache_embeddings: CacheBackedEmbeddings) -> None:
text = "query_text"
vector = cache_embeddings.embed_query(text)
expected_vector = [5.0, 6.0]
assert vector == expected_vector
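
One additional check that could be worth adding (a sketch, not part of this PR): the namespace keeps the same text from colliding across models in a shared store.

def test_different_namespaces_do_not_collide() -> None:
    # The same text cached under two namespaces should produce two distinct keys.
    store = InMemoryStore()
    embedder_a = CacheBackedEmbeddings.from_bytes_store(
        MockEmbeddings(), store, namespace="model-a"
    )
    embedder_b = CacheBackedEmbeddings.from_bytes_store(
        MockEmbeddings(), store, namespace="model-b"
    )
    embedder_a.embed_documents(["hello"])
    embedder_b.embed_documents(["hello"])
    keys = sorted(store.yield_keys())
    assert len(keys) == 2
    assert keys[0].startswith("model-a")
    assert keys[1].startswith("model-b")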