FEATURE: Astra DB, LLM cache classes (exact-match and semantic cache) (#13834)

This PR provides idiomatic implementations for the exact-match and the
semantic LLM caches using Astra DB as backend through the database's
HTTP JSON API. These caches require the `astrapy` library as dependency.

Comes with integration tests and example usage in the `llm_cache.ipynb`
in the docs.

@baskaryan this is the Astra DB counterpart for the Cassandra classes
you merged some time ago, tagging you for your familiarity with the
topic. Thank you!

---------

Co-authored-by: Bagatur <baskaryan@gmail.com>
This commit is contained in:
Stefano Lottini
2023-11-25 03:53:37 +01:00
committed by GitHub
parent 272df9dcae
commit 19c68c7652
4 changed files with 652 additions and 3 deletions

View File

@@ -1239,3 +1239,318 @@ class SQLAlchemyMd5Cache(BaseCache):
@staticmethod
def get_md5(input_string: str) -> str:
return hashlib.md5(input_string.encode()).hexdigest()
ASTRA_DB_CACHE_DEFAULT_COLLECTION_NAME = "langchain_astradb_cache"
class AstraDBCache(BaseCache):
"""
Cache that uses Astra DB as a backend.
It uses a single collection as a kv store
The lookup keys, combined in the _id of the documents, are:
- prompt, a string
- llm_string, a deterministic str representation of the model parameters.
(needed to prevent same-prompt-different-model collisions)
"""
def __init__(
self,
*,
collection_name: str = ASTRA_DB_CACHE_DEFAULT_COLLECTION_NAME,
token: Optional[str] = None,
api_endpoint: Optional[str] = None,
astra_db_client: Optional[Any] = None, # 'astrapy.db.AstraDB' if passed
namespace: Optional[str] = None,
):
"""
Create an AstraDB cache using a collection for storage.
Args (only keyword-arguments accepted):
collection_name (str): name of the Astra DB collection to create/use.
token (Optional[str]): API token for Astra DB usage.
api_endpoint (Optional[str]): full URL to the API endpoint,
such as "https://<DB-ID>-us-east1.apps.astra.datastax.com".
astra_db_client (Optional[Any]): *alternative to token+api_endpoint*,
you can pass an already-created 'astrapy.db.AstraDB' instance.
namespace (Optional[str]): namespace (aka keyspace) where the
collection is created. Defaults to the database's "default namespace".
"""
try:
from astrapy.db import (
AstraDB as LibAstraDB,
)
except (ImportError, ModuleNotFoundError):
raise ImportError(
"Could not import a recent astrapy python package. "
"Please install it with `pip install --upgrade astrapy`."
)
# Conflicting-arg checks:
if astra_db_client is not None:
if token is not None or api_endpoint is not None:
raise ValueError(
"You cannot pass 'astra_db_client' to AstraDB if passing "
"'token' and 'api_endpoint'."
)
self.collection_name = collection_name
self.token = token
self.api_endpoint = api_endpoint
self.namespace = namespace
if astra_db_client is not None:
self.astra_db = astra_db_client
else:
self.astra_db = LibAstraDB(
token=self.token,
api_endpoint=self.api_endpoint,
namespace=self.namespace,
)
self.collection = self.astra_db.create_collection(
collection_name=self.collection_name,
)
@staticmethod
def _make_id(prompt: str, llm_string: str) -> str:
return f"{_hash(prompt)}#{_hash(llm_string)}"
def lookup(self, prompt: str, llm_string: str) -> Optional[RETURN_VAL_TYPE]:
"""Look up based on prompt and llm_string."""
doc_id = self._make_id(prompt, llm_string)
item = self.collection.find_one(
filter={
"_id": doc_id,
},
projection={
"body_blob": 1,
},
)["data"]["document"]
if item is not None:
generations = _loads_generations(item["body_blob"])
# this protects against malformed cached items:
if generations is not None:
return generations
else:
return None
else:
return None
def update(self, prompt: str, llm_string: str, return_val: RETURN_VAL_TYPE) -> None:
"""Update cache based on prompt and llm_string."""
doc_id = self._make_id(prompt, llm_string)
blob = _dumps_generations(return_val)
self.collection.upsert(
{
"_id": doc_id,
"body_blob": blob,
},
)
def delete_through_llm(
self, prompt: str, llm: LLM, stop: Optional[List[str]] = None
) -> None:
"""
A wrapper around `delete` with the LLM being passed.
In case the llm(prompt) calls have a `stop` param, you should pass it here
"""
llm_string = get_prompts(
{**llm.dict(), **{"stop": stop}},
[],
)[1]
return self.delete(prompt, llm_string=llm_string)
def delete(self, prompt: str, llm_string: str) -> None:
"""Evict from cache if there's an entry."""
doc_id = self._make_id(prompt, llm_string)
return self.collection.delete_one(doc_id)
def clear(self, **kwargs: Any) -> None:
"""Clear cache. This is for all LLMs at once."""
self.astra_db.truncate_collection(self.collection_name)
ASTRA_DB_SEMANTIC_CACHE_DEFAULT_THRESHOLD = 0.85
ASTRA_DB_CACHE_DEFAULT_COLLECTION_NAME = "langchain_astradb_semantic_cache"
ASTRA_DB_SEMANTIC_CACHE_EMBEDDING_CACHE_SIZE = 16
class AstraDBSemanticCache(BaseCache):
"""
Cache that uses Astra DB as a vector-store backend for semantic
(i.e. similarity-based) lookup.
It uses a single (vector) collection and can store
cached values from several LLMs, so the LLM's 'llm_string' is stored
in the document metadata.
You can choose the preferred similarity (or use the API default) --
remember the threshold might require metric-dependend tuning.
"""
def __init__(
self,
*,
collection_name: str = ASTRA_DB_CACHE_DEFAULT_COLLECTION_NAME,
token: Optional[str] = None,
api_endpoint: Optional[str] = None,
astra_db_client: Optional[Any] = None, # 'astrapy.db.AstraDB' if passed
namespace: Optional[str] = None,
embedding: Embeddings,
metric: Optional[str] = None,
similarity_threshold: float = ASTRA_DB_SEMANTIC_CACHE_DEFAULT_THRESHOLD,
):
"""
Initialize the cache with all relevant parameters.
Args:
collection_name (str): name of the Astra DB collection to create/use.
token (Optional[str]): API token for Astra DB usage.
api_endpoint (Optional[str]): full URL to the API endpoint,
such as "https://<DB-ID>-us-east1.apps.astra.datastax.com".
astra_db_client (Optional[Any]): *alternative to token+api_endpoint*,
you can pass an already-created 'astrapy.db.AstraDB' instance.
namespace (Optional[str]): namespace (aka keyspace) where the
collection is created. Defaults to the database's "default namespace".
embedding (Embedding): Embedding provider for semantic
encoding and search.
metric: the function to use for evaluating similarity of text embeddings.
Defaults to 'cosine' (alternatives: 'euclidean', 'dot_product')
similarity_threshold (float, optional): the minimum similarity
for accepting a (semantic-search) match.
The default score threshold is tuned to the default metric.
Tune it carefully yourself if switching to another distance metric.
"""
try:
from astrapy.db import (
AstraDB as LibAstraDB,
)
except (ImportError, ModuleNotFoundError):
raise ImportError(
"Could not import a recent astrapy python package. "
"Please install it with `pip install --upgrade astrapy`."
)
# Conflicting-arg checks:
if astra_db_client is not None:
if token is not None or api_endpoint is not None:
raise ValueError(
"You cannot pass 'astra_db_client' to AstraDB if passing "
"'token' and 'api_endpoint'."
)
self.embedding = embedding
self.metric = metric
self.similarity_threshold = similarity_threshold
# The contract for this class has separate lookup and update:
# in order to spare some embedding calculations we cache them between
# the two calls.
# Note: each instance of this class has its own `_get_embedding` with
# its own lru.
@lru_cache(maxsize=ASTRA_DB_SEMANTIC_CACHE_EMBEDDING_CACHE_SIZE)
def _cache_embedding(text: str) -> List[float]:
return self.embedding.embed_query(text=text)
self._get_embedding = _cache_embedding
self.embedding_dimension = self._get_embedding_dimension()
self.collection_name = collection_name
self.token = token
self.api_endpoint = api_endpoint
self.namespace = namespace
if astra_db_client is not None:
self.astra_db = astra_db_client
else:
self.astra_db = LibAstraDB(
token=self.token,
api_endpoint=self.api_endpoint,
namespace=self.namespace,
)
self.collection = self.astra_db.create_collection(
collection_name=self.collection_name,
dimension=self.embedding_dimension,
metric=self.metric,
)
def _get_embedding_dimension(self) -> int:
return len(self._get_embedding(text="This is a sample sentence."))
@staticmethod
def _make_id(prompt: str, llm_string: str) -> str:
return f"{_hash(prompt)}#{_hash(llm_string)}"
def update(self, prompt: str, llm_string: str, return_val: RETURN_VAL_TYPE) -> None:
"""Update cache based on prompt and llm_string."""
doc_id = self._make_id(prompt, llm_string)
llm_string_hash = _hash(llm_string)
embedding_vector = self._get_embedding(text=prompt)
body = _dumps_generations(return_val)
#
self.collection.upsert(
{
"_id": doc_id,
"body_blob": body,
"llm_string_hash": llm_string_hash,
"$vector": embedding_vector,
}
)
def lookup(self, prompt: str, llm_string: str) -> Optional[RETURN_VAL_TYPE]:
"""Look up based on prompt and llm_string."""
hit_with_id = self.lookup_with_id(prompt, llm_string)
if hit_with_id is not None:
return hit_with_id[1]
else:
return None
def lookup_with_id(
self, prompt: str, llm_string: str
) -> Optional[Tuple[str, RETURN_VAL_TYPE]]:
"""
Look up based on prompt and llm_string.
If there are hits, return (document_id, cached_entry) for the top hit
"""
prompt_embedding: List[float] = self._get_embedding(text=prompt)
llm_string_hash = _hash(llm_string)
hit = self.collection.vector_find_one(
vector=prompt_embedding,
filter={
"llm_string_hash": llm_string_hash,
},
fields=["body_blob", "_id"],
include_similarity=True,
)
if hit is None or hit["$similarity"] < self.similarity_threshold:
return None
else:
generations = _loads_generations(hit["body_blob"])
if generations is not None:
# this protects against malformed cached items:
return (hit["_id"], generations)
else:
return None
def lookup_with_id_through_llm(
self, prompt: str, llm: LLM, stop: Optional[List[str]] = None
) -> Optional[Tuple[str, RETURN_VAL_TYPE]]:
llm_string = get_prompts(
{**llm.dict(), **{"stop": stop}},
[],
)[1]
return self.lookup_with_id(prompt, llm_string=llm_string)
def delete_by_document_id(self, document_id: str) -> None:
"""
Given this is a "similarity search" cache, an invalidation pattern
that makes sense is first a lookup to get an ID, and then deleting
with that ID. This is for the second step.
"""
self.collection.delete_one(document_id)
def clear(self, **kwargs: Any) -> None:
"""Clear the *whole* semantic cache."""
self.astra_db.truncate_collection(self.collection_name)

View File

@@ -0,0 +1,99 @@
"""
Test AstraDB caches. Requires an Astra DB vector instance.
Required to run this test:
- a recent `astrapy` Python package available
- an Astra DB instance;
- the two environment variables set:
export ASTRA_DB_API_ENDPOINT="https://<DB-ID>-us-east1.apps.astra.datastax.com"
export ASTRA_DB_APPLICATION_TOKEN="AstraCS:........."
- optionally this as well (otherwise defaults are used):
export ASTRA_DB_KEYSPACE="my_keyspace"
"""
import os
from typing import Iterator
import pytest
from langchain_core.outputs import Generation, LLMResult
from langchain.cache import AstraDBCache, AstraDBSemanticCache
from langchain.globals import get_llm_cache, set_llm_cache
from tests.integration_tests.vectorstores.fake_embeddings import FakeEmbeddings
from tests.unit_tests.llms.fake_llm import FakeLLM
def _has_env_vars() -> bool:
return all(
[
"ASTRA_DB_APPLICATION_TOKEN" in os.environ,
"ASTRA_DB_API_ENDPOINT" in os.environ,
]
)
@pytest.fixture(scope="module")
def astradb_cache() -> Iterator[AstraDBCache]:
cache = AstraDBCache(
collection_name="lc_integration_test_cache",
token=os.environ["ASTRA_DB_APPLICATION_TOKEN"],
api_endpoint=os.environ["ASTRA_DB_API_ENDPOINT"],
namespace=os.environ.get("ASTRA_DB_KEYSPACE"),
)
yield cache
cache.astra_db.delete_collection("lc_integration_test_cache")
@pytest.fixture(scope="module")
def astradb_semantic_cache() -> Iterator[AstraDBSemanticCache]:
fake_embe = FakeEmbeddings()
sem_cache = AstraDBSemanticCache(
collection_name="lc_integration_test_sem_cache",
token=os.environ["ASTRA_DB_APPLICATION_TOKEN"],
api_endpoint=os.environ["ASTRA_DB_API_ENDPOINT"],
namespace=os.environ.get("ASTRA_DB_KEYSPACE"),
embedding=fake_embe,
)
yield sem_cache
sem_cache.astra_db.delete_collection("lc_integration_test_cache")
@pytest.mark.requires("astrapy")
@pytest.mark.skipif(not _has_env_vars(), reason="Missing Astra DB env. vars")
class TestAstraDBCaches:
def test_astradb_cache(self, astradb_cache: AstraDBCache) -> None:
set_llm_cache(astradb_cache)
llm = FakeLLM()
params = llm.dict()
params["stop"] = None
llm_string = str(sorted([(k, v) for k, v in params.items()]))
get_llm_cache().update("foo", llm_string, [Generation(text="fizz")])
output = llm.generate(["foo"])
print(output)
expected_output = LLMResult(
generations=[[Generation(text="fizz")]],
llm_output={},
)
print(expected_output)
assert output == expected_output
astradb_cache.clear()
def test_cassandra_semantic_cache(
self, astradb_semantic_cache: AstraDBSemanticCache
) -> None:
set_llm_cache(astradb_semantic_cache)
llm = FakeLLM()
params = llm.dict()
params["stop"] = None
llm_string = str(sorted([(k, v) for k, v in params.items()]))
get_llm_cache().update("foo", llm_string, [Generation(text="fizz")])
output = llm.generate(["bar"]) # same embedding as 'foo'
expected_output = LLMResult(
generations=[[Generation(text="fizz")]],
llm_output={},
)
assert output == expected_output
# clear the cache
astradb_semantic_cache.clear()
output = llm.generate(["bar"]) # 'fizz' is erased away now
assert output != expected_output
astradb_semantic_cache.clear()