Mirror of https://github.com/hwchase17/langchain.git, synced 2025-09-15 14:36:54 +00:00
FEATURE: Astra DB, LLM cache classes (exact-match and semantic cache) (#13834)
This PR provides idiomatic implementations of the exact-match and semantic LLM caches, using Astra DB as the backend through the database's HTTP JSON API. Both caches require the `astrapy` library as a dependency. The PR comes with integration tests and example usage in `llm_cache.ipynb` in the docs. @baskaryan this is the Astra DB counterpart of the Cassandra classes you merged some time ago; tagging you for your familiarity with the topic. Thank you! --------- Co-authored-by: Bagatur <baskaryan@gmail.com>
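For orientation, a minimal sketch (collection name hypothetical, env-var names as in the integration tests below) of how either cache plugs into LangChain's global LLM-cache hook; see `llm_cache.ipynb` for the full walkthrough:

import os

from langchain.cache import AstraDBCache
from langchain.globals import set_llm_cache

set_llm_cache(
    AstraDBCache(
        collection_name="my_llm_cache",  # hypothetical name
        token=os.environ["ASTRA_DB_APPLICATION_TOKEN"],
        api_endpoint=os.environ["ASTRA_DB_API_ENDPOINT"],
    )
)
# From here on, identical (prompt, llm_string) calls are served from Astra DB.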
@@ -1239,3 +1239,318 @@ class SQLAlchemyMd5Cache(BaseCache):
    @staticmethod
    def get_md5(input_string: str) -> str:
        return hashlib.md5(input_string.encode()).hexdigest()


ASTRA_DB_CACHE_DEFAULT_COLLECTION_NAME = "langchain_astradb_cache"


class AstraDBCache(BaseCache):
    """
    Cache that uses Astra DB as a backend.

    It uses a single collection as a key-value store.
    The lookup keys, combined in the _id of the documents, are:
        - prompt, a string
        - llm_string, a deterministic str representation of the model parameters
          (needed to prevent same-prompt-different-model collisions)
    """

    def __init__(
        self,
        *,
        collection_name: str = ASTRA_DB_CACHE_DEFAULT_COLLECTION_NAME,
        token: Optional[str] = None,
        api_endpoint: Optional[str] = None,
        astra_db_client: Optional[Any] = None,  # 'astrapy.db.AstraDB' if passed
        namespace: Optional[str] = None,
    ):
        """
        Create an AstraDB cache using a collection for storage.

        Args (only keyword-arguments accepted):
            collection_name (str): name of the Astra DB collection to create/use.
            token (Optional[str]): API token for Astra DB usage.
            api_endpoint (Optional[str]): full URL to the API endpoint,
                such as "https://<DB-ID>-us-east1.apps.astra.datastax.com".
            astra_db_client (Optional[Any]): *alternative to token+api_endpoint*,
                you can pass an already-created 'astrapy.db.AstraDB' instance.
            namespace (Optional[str]): namespace (aka keyspace) where the
                collection is created. Defaults to the database's "default namespace".
        """
        try:
            from astrapy.db import AstraDB as LibAstraDB
        except (ImportError, ModuleNotFoundError):
            raise ImportError(
                "Could not import a recent astrapy python package. "
                "Please install it with `pip install --upgrade astrapy`."
            )
        # Conflicting-arg checks:
        if astra_db_client is not None:
            if token is not None or api_endpoint is not None:
                raise ValueError(
                    "You cannot pass 'astra_db_client' together with "
                    "'token' or 'api_endpoint'."
                )

        self.collection_name = collection_name
        self.token = token
        self.api_endpoint = api_endpoint
        self.namespace = namespace

        if astra_db_client is not None:
            self.astra_db = astra_db_client
        else:
            self.astra_db = LibAstraDB(
                token=self.token,
                api_endpoint=self.api_endpoint,
                namespace=self.namespace,
            )
        self.collection = self.astra_db.create_collection(
            collection_name=self.collection_name,
        )

    @staticmethod
    def _make_id(prompt: str, llm_string: str) -> str:
        return f"{_hash(prompt)}#{_hash(llm_string)}"

    def lookup(self, prompt: str, llm_string: str) -> Optional[RETURN_VAL_TYPE]:
        """Look up based on prompt and llm_string."""
        doc_id = self._make_id(prompt, llm_string)
        item = self.collection.find_one(
            filter={
                "_id": doc_id,
            },
            projection={
                "body_blob": 1,
            },
        )["data"]["document"]
        if item is not None:
            generations = _loads_generations(item["body_blob"])
            # this protects against malformed cached items:
            if generations is not None:
                return generations
            else:
                return None
        else:
            return None

    def update(self, prompt: str, llm_string: str, return_val: RETURN_VAL_TYPE) -> None:
        """Update cache based on prompt and llm_string."""
        doc_id = self._make_id(prompt, llm_string)
        blob = _dumps_generations(return_val)
        self.collection.upsert(
            {
                "_id": doc_id,
                "body_blob": blob,
            },
        )

    def delete_through_llm(
        self, prompt: str, llm: LLM, stop: Optional[List[str]] = None
    ) -> None:
        """
        A wrapper around `delete` with the LLM being passed.
        In case the llm(prompt) calls have a `stop` param, you should pass it here.
        """
        llm_string = get_prompts(
            {**llm.dict(), **{"stop": stop}},
            [],
        )[1]
        self.delete(prompt, llm_string=llm_string)

    def delete(self, prompt: str, llm_string: str) -> None:
        """Evict from cache if there's an entry."""
        doc_id = self._make_id(prompt, llm_string)
        self.collection.delete_one(doc_id)

    def clear(self, **kwargs: Any) -> None:
        """Clear cache. This is for all LLMs at once."""
        self.astra_db.truncate_collection(self.collection_name)

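Both classes key entries on the (prompt, llm_string) pair hashed into the document `_id`. The `_hash` helper is defined earlier in cache.py and is not part of this hunk; the sketch below assumes an MD5 hex digest in the spirit of `get_md5` above, just to show why same-prompt-different-model calls cannot collide:

import hashlib


def _hash(s: str) -> str:
    # stand-in for cache.py's _hash (assumed MD5-like; not shown in this hunk)
    return hashlib.md5(s.encode()).hexdigest()


def make_id(prompt: str, llm_string: str) -> str:
    # mirrors AstraDBCache._make_id above
    return f"{_hash(prompt)}#{_hash(llm_string)}"


# Same prompt, two different model configurations -> two distinct _ids:
assert make_id("Tell me a joke", "model_a, temp=0.0") != make_id(
    "Tell me a joke", "model_b, temp=0.0"
)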
ASTRA_DB_SEMANTIC_CACHE_DEFAULT_THRESHOLD = 0.85
ASTRA_DB_SEMANTIC_CACHE_DEFAULT_COLLECTION_NAME = "langchain_astradb_semantic_cache"
ASTRA_DB_SEMANTIC_CACHE_EMBEDDING_CACHE_SIZE = 16


class AstraDBSemanticCache(BaseCache):
    """
    Cache that uses Astra DB as a vector-store backend for semantic
    (i.e. similarity-based) lookup.

    It uses a single (vector) collection and can store
    cached values from several LLMs, so the LLM's 'llm_string' is stored
    in the document metadata.

    You can choose the preferred similarity (or use the API default) --
    remember the threshold might require metric-dependent tuning.
    """

    def __init__(
        self,
        *,
        collection_name: str = ASTRA_DB_SEMANTIC_CACHE_DEFAULT_COLLECTION_NAME,
        token: Optional[str] = None,
        api_endpoint: Optional[str] = None,
        astra_db_client: Optional[Any] = None,  # 'astrapy.db.AstraDB' if passed
        namespace: Optional[str] = None,
        embedding: Embeddings,
        metric: Optional[str] = None,
        similarity_threshold: float = ASTRA_DB_SEMANTIC_CACHE_DEFAULT_THRESHOLD,
    ):
        """
        Initialize the cache with all relevant parameters.

        Args:
            collection_name (str): name of the Astra DB collection to create/use.
            token (Optional[str]): API token for Astra DB usage.
            api_endpoint (Optional[str]): full URL to the API endpoint,
                such as "https://<DB-ID>-us-east1.apps.astra.datastax.com".
            astra_db_client (Optional[Any]): *alternative to token+api_endpoint*,
                you can pass an already-created 'astrapy.db.AstraDB' instance.
            namespace (Optional[str]): namespace (aka keyspace) where the
                collection is created. Defaults to the database's "default namespace".
            embedding (Embeddings): Embedding provider for semantic
                encoding and search.
            metric: the function to use for evaluating similarity of text embeddings.
                Defaults to 'cosine' (alternatives: 'euclidean', 'dot_product').
            similarity_threshold (float, optional): the minimum similarity
                for accepting a (semantic-search) match.

        The default score threshold is tuned to the default metric.
        Tune it carefully yourself if switching to another distance metric.
        """
        try:
            from astrapy.db import AstraDB as LibAstraDB
        except (ImportError, ModuleNotFoundError):
            raise ImportError(
                "Could not import a recent astrapy python package. "
                "Please install it with `pip install --upgrade astrapy`."
            )
        # Conflicting-arg checks:
        if astra_db_client is not None:
            if token is not None or api_endpoint is not None:
                raise ValueError(
                    "You cannot pass 'astra_db_client' together with "
                    "'token' or 'api_endpoint'."
                )

        self.embedding = embedding
        self.metric = metric
        self.similarity_threshold = similarity_threshold

        # The contract for this class has separate lookup and update:
        # in order to spare some embedding calculations we cache them between
        # the two calls.
        # Note: each instance of this class has its own `_get_embedding` with
        # its own lru.
        @lru_cache(maxsize=ASTRA_DB_SEMANTIC_CACHE_EMBEDDING_CACHE_SIZE)
        def _cache_embedding(text: str) -> List[float]:
            return self.embedding.embed_query(text=text)

        self._get_embedding = _cache_embedding
        self.embedding_dimension = self._get_embedding_dimension()

        self.collection_name = collection_name
        self.token = token
        self.api_endpoint = api_endpoint
        self.namespace = namespace

        if astra_db_client is not None:
            self.astra_db = astra_db_client
        else:
            self.astra_db = LibAstraDB(
                token=self.token,
                api_endpoint=self.api_endpoint,
                namespace=self.namespace,
            )
        self.collection = self.astra_db.create_collection(
            collection_name=self.collection_name,
            dimension=self.embedding_dimension,
            metric=self.metric,
        )

    def _get_embedding_dimension(self) -> int:
        return len(self._get_embedding(text="This is a sample sentence."))

    @staticmethod
    def _make_id(prompt: str, llm_string: str) -> str:
        return f"{_hash(prompt)}#{_hash(llm_string)}"

    def update(self, prompt: str, llm_string: str, return_val: RETURN_VAL_TYPE) -> None:
        """Update cache based on prompt and llm_string."""
        doc_id = self._make_id(prompt, llm_string)
        llm_string_hash = _hash(llm_string)
        embedding_vector = self._get_embedding(text=prompt)
        body = _dumps_generations(return_val)
        self.collection.upsert(
            {
                "_id": doc_id,
                "body_blob": body,
                "llm_string_hash": llm_string_hash,
                "$vector": embedding_vector,
            }
        )

    def lookup(self, prompt: str, llm_string: str) -> Optional[RETURN_VAL_TYPE]:
        """Look up based on prompt and llm_string."""
        hit_with_id = self.lookup_with_id(prompt, llm_string)
        if hit_with_id is not None:
            return hit_with_id[1]
        else:
            return None

    def lookup_with_id(
        self, prompt: str, llm_string: str
    ) -> Optional[Tuple[str, RETURN_VAL_TYPE]]:
        """
        Look up based on prompt and llm_string.
        If there are hits, return (document_id, cached_entry) for the top hit.
        """
        prompt_embedding: List[float] = self._get_embedding(text=prompt)
        llm_string_hash = _hash(llm_string)

        hit = self.collection.vector_find_one(
            vector=prompt_embedding,
            filter={
                "llm_string_hash": llm_string_hash,
            },
            fields=["body_blob", "_id"],
            include_similarity=True,
        )

        if hit is None or hit["$similarity"] < self.similarity_threshold:
            return None
        else:
            generations = _loads_generations(hit["body_blob"])
            if generations is not None:
                # this protects against malformed cached items:
                return (hit["_id"], generations)
            else:
                return None

    def lookup_with_id_through_llm(
        self, prompt: str, llm: LLM, stop: Optional[List[str]] = None
    ) -> Optional[Tuple[str, RETURN_VAL_TYPE]]:
        """A wrapper around `lookup_with_id` with the LLM being passed."""
        llm_string = get_prompts(
            {**llm.dict(), **{"stop": stop}},
            [],
        )[1]
        return self.lookup_with_id(prompt, llm_string=llm_string)

    def delete_by_document_id(self, document_id: str) -> None:
        """
        Given this is a "similarity search" cache, an invalidation pattern
        that makes sense is first a lookup to get an ID, and then deleting
        with that ID. This is for the second step.
        """
        self.collection.delete_one(document_id)

    def clear(self, **kwargs: Any) -> None:
        """Clear the *whole* semantic cache."""
        self.astra_db.truncate_collection(self.collection_name)

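A minimal usage sketch of the semantic cache, including the lookup-then-delete invalidation pattern described in `delete_by_document_id` above. Collection name and the "<llm-string>" placeholder are hypothetical; this assumes OpenAIEmbeddings (any Embeddings implementation works) and the same environment variables as the tests below:

import os

from langchain.cache import AstraDBSemanticCache
from langchain.embeddings import OpenAIEmbeddings
from langchain.globals import set_llm_cache

sem_cache = AstraDBSemanticCache(
    collection_name="my_semantic_cache",  # hypothetical name
    token=os.environ["ASTRA_DB_APPLICATION_TOKEN"],
    api_endpoint=os.environ["ASTRA_DB_API_ENDPOINT"],
    embedding=OpenAIEmbeddings(),
    metric="cosine",  # the API default; alternatives per the docstring above
    similarity_threshold=0.85,  # tuned to the default metric
)
set_llm_cache(sem_cache)

# Targeted invalidation: first look up the top hit to learn its document
# _id, then delete by that _id.
hit = sem_cache.lookup_with_id("some prompt", llm_string="<llm-string>")
if hit is not None:
    doc_id, _generations = hit
    sem_cache.delete_by_document_id(doc_id)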
libs/langchain/tests/integration_tests/cache/test_astradb.py (vendored, new file, 99 lines)
@@ -0,0 +1,99 @@
"""
Test AstraDB caches. Requires an Astra DB vector instance.

Required to run this test:
    - a recent `astrapy` Python package available
    - an Astra DB instance;
    - the two environment variables set:
        export ASTRA_DB_API_ENDPOINT="https://<DB-ID>-us-east1.apps.astra.datastax.com"
        export ASTRA_DB_APPLICATION_TOKEN="AstraCS:........."
    - optionally this as well (otherwise defaults are used):
        export ASTRA_DB_KEYSPACE="my_keyspace"
"""
import os
from typing import Iterator

import pytest
from langchain_core.outputs import Generation, LLMResult

from langchain.cache import AstraDBCache, AstraDBSemanticCache
from langchain.globals import get_llm_cache, set_llm_cache
from tests.integration_tests.vectorstores.fake_embeddings import FakeEmbeddings
from tests.unit_tests.llms.fake_llm import FakeLLM


def _has_env_vars() -> bool:
    return all(
        [
            "ASTRA_DB_APPLICATION_TOKEN" in os.environ,
            "ASTRA_DB_API_ENDPOINT" in os.environ,
        ]
    )


@pytest.fixture(scope="module")
def astradb_cache() -> Iterator[AstraDBCache]:
    cache = AstraDBCache(
        collection_name="lc_integration_test_cache",
        token=os.environ["ASTRA_DB_APPLICATION_TOKEN"],
        api_endpoint=os.environ["ASTRA_DB_API_ENDPOINT"],
        namespace=os.environ.get("ASTRA_DB_KEYSPACE"),
    )
    yield cache
    cache.astra_db.delete_collection("lc_integration_test_cache")


@pytest.fixture(scope="module")
def astradb_semantic_cache() -> Iterator[AstraDBSemanticCache]:
    fake_embe = FakeEmbeddings()
    sem_cache = AstraDBSemanticCache(
        collection_name="lc_integration_test_sem_cache",
        token=os.environ["ASTRA_DB_APPLICATION_TOKEN"],
        api_endpoint=os.environ["ASTRA_DB_API_ENDPOINT"],
        namespace=os.environ.get("ASTRA_DB_KEYSPACE"),
        embedding=fake_embe,
    )
    yield sem_cache
    # tear down the collection this fixture created:
    sem_cache.astra_db.delete_collection("lc_integration_test_sem_cache")


@pytest.mark.requires("astrapy")
@pytest.mark.skipif(not _has_env_vars(), reason="Missing Astra DB env. vars")
class TestAstraDBCaches:
    def test_astradb_cache(self, astradb_cache: AstraDBCache) -> None:
        set_llm_cache(astradb_cache)
        llm = FakeLLM()
        params = llm.dict()
        params["stop"] = None
        llm_string = str(sorted([(k, v) for k, v in params.items()]))
        get_llm_cache().update("foo", llm_string, [Generation(text="fizz")])
        output = llm.generate(["foo"])
        expected_output = LLMResult(
            generations=[[Generation(text="fizz")]],
            llm_output={},
        )
        assert output == expected_output
        astradb_cache.clear()

    def test_astradb_semantic_cache(
        self, astradb_semantic_cache: AstraDBSemanticCache
    ) -> None:
        set_llm_cache(astradb_semantic_cache)
        llm = FakeLLM()
        params = llm.dict()
        params["stop"] = None
        llm_string = str(sorted([(k, v) for k, v in params.items()]))
        get_llm_cache().update("foo", llm_string, [Generation(text="fizz")])
        output = llm.generate(["bar"])  # same embedding as 'foo'
        expected_output = LLMResult(
            generations=[[Generation(text="fizz")]],
            llm_output={},
        )
        assert output == expected_output
        # clear the cache
        astradb_semantic_cache.clear()
        output = llm.generate(["bar"])  # 'fizz' is erased away now
        assert output != expected_output
        astradb_semantic_cache.clear()
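A note on the semantic test's `llm.generate(["bar"])` hit: it relies on the test-suite FakeEmbeddings returning one constant vector for every query, so any prompt is a perfect-similarity match for the entry cached under "foo". A minimal stand-in for that behavior (the real helper lives in tests/integration_tests/vectorstores/fake_embeddings.py and may differ in detail):

from typing import List

from langchain_core.embeddings import Embeddings


class ConstantEmbeddings(Embeddings):
    """Every text maps to the same vector, so any two queries match."""

    def embed_documents(self, texts: List[str]) -> List[List[float]]:
        return [[1.0] * 9 + [0.0] for _ in texts]

    def embed_query(self, text: str) -> List[float]:
        return [1.0] * 9 + [0.0]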