community: Add async methods to AstraDBCache (#17415)

Adds async methods to AstraDBCache
This commit is contained in:
Christophe Bornet
2024-02-15 05:10:08 +01:00
committed by GitHub
parent e438fe6be9
commit ca2d4078f3
6 changed files with 465 additions and 106 deletions

View File

@@ -29,12 +29,14 @@ import uuid
import warnings
from abc import ABC
from datetime import timedelta
from functools import lru_cache
from functools import lru_cache, wraps
from typing import (
TYPE_CHECKING,
Any,
Awaitable,
Callable,
Dict,
Generator,
List,
Optional,
Sequence,
@@ -56,20 +58,23 @@ except ImportError:
from langchain_core.caches import RETURN_VAL_TYPE, BaseCache
from langchain_core.embeddings import Embeddings
from langchain_core.language_models.llms import LLM, get_prompts
from langchain_core.language_models.llms import LLM, aget_prompts, get_prompts
from langchain_core.load.dump import dumps
from langchain_core.load.load import loads
from langchain_core.outputs import ChatGeneration, Generation
from langchain_core.utils import get_from_env
from langchain_community.utilities.astradb import AstraDBEnvironment
from langchain_community.utilities.astradb import (
SetupMode,
_AstraDBCollectionEnvironment,
)
from langchain_community.vectorstores.redis import Redis as RedisVectorstore
logger = logging.getLogger(__file__)
if TYPE_CHECKING:
import momento
from astrapy.db import AstraDB
from astrapy.db import AstraDB, AsyncAstraDB
from cassandra.cluster import Session as CassandraSession
@@ -1371,6 +1376,10 @@ class AstraDBCache(BaseCache):
(needed to prevent same-prompt-different-model collisions)
"""
@staticmethod
def _make_id(prompt: str, llm_string: str) -> str:
    """Build the document id used to store a (prompt, llm_string) pair.

    The id is the hash of the prompt and the hash of the llm_string,
    joined by a ``#`` character.
    """
    prompt_part = _hash(prompt)
    llm_part = _hash(llm_string)
    return f"{prompt_part}#{llm_part}"
def __init__(
self,
*,
@@ -1378,7 +1387,10 @@ class AstraDBCache(BaseCache):
token: Optional[str] = None,
api_endpoint: Optional[str] = None,
astra_db_client: Optional[AstraDB] = None,
async_astra_db_client: Optional[AsyncAstraDB] = None,
namespace: Optional[str] = None,
pre_delete_collection: bool = False,
setup_mode: SetupMode = SetupMode.SYNC,
):
"""
Create an AstraDB cache using a collection for storage.
@@ -1388,29 +1400,35 @@ class AstraDBCache(BaseCache):
token (Optional[str]): API token for Astra DB usage.
api_endpoint (Optional[str]): full URL to the API endpoint,
such as "https://<DB-ID>-us-east1.apps.astra.datastax.com".
astra_db_client (Optional[Any]): *alternative to token+api_endpoint*,
astra_db_client (Optional[AstraDB]):
*alternative to token+api_endpoint*,
you can pass an already-created 'astrapy.db.AstraDB' instance.
async_astra_db_client (Optional[AsyncAstraDB]):
*alternative to token+api_endpoint*,
you can pass an already-created 'astrapy.db.AsyncAstraDB' instance.
namespace (Optional[str]): namespace (aka keyspace) where the
collection is created. Defaults to the database's "default namespace".
pre_delete_collection (bool): whether to delete and re-create the
collection. Defaults to False.
setup_mode (SetupMode): mode used to create the collection in the DB
(SYNC, ASYNC or OFF). Defaults to SYNC.
"""
astra_env = AstraDBEnvironment(
self.astra_env = _AstraDBCollectionEnvironment(
collection_name=collection_name,
token=token,
api_endpoint=api_endpoint,
astra_db_client=astra_db_client,
async_astra_db_client=async_astra_db_client,
namespace=namespace,
setup_mode=setup_mode,
pre_delete_collection=pre_delete_collection,
)
self.astra_db = astra_env.astra_db
self.collection = self.astra_db.create_collection(
collection_name=collection_name,
)
self.collection_name = collection_name
@staticmethod
def _make_id(prompt: str, llm_string: str) -> str:
return f"{_hash(prompt)}#{_hash(llm_string)}"
self.collection = self.astra_env.collection
self.async_collection = self.astra_env.async_collection
def lookup(self, prompt: str, llm_string: str) -> Optional[RETURN_VAL_TYPE]:
"""Look up based on prompt and llm_string."""
self.astra_env.ensure_db_setup()
doc_id = self._make_id(prompt, llm_string)
item = self.collection.find_one(
filter={
@@ -1420,18 +1438,27 @@ class AstraDBCache(BaseCache):
"body_blob": 1,
},
)["data"]["document"]
if item is not None:
generations = _loads_generations(item["body_blob"])
# this protects against malformed cached items:
if generations is not None:
return generations
else:
return None
else:
return None
return _loads_generations(item["body_blob"]) if item is not None else None
async def alookup(self, prompt: str, llm_string: str) -> Optional[RETURN_VAL_TYPE]:
    """Asynchronously look up the cached entry for a prompt/llm_string pair.

    Awaits ``astra_env.aensure_db_setup()`` and then queries the async
    collection for the document whose ``_id`` is built by ``_make_id``,
    projecting only the ``body_blob`` field.

    Returns:
        ``None`` when no document is found, otherwise the value produced by
        ``_loads_generations`` for the stored ``body_blob``.
    """
    await self.astra_env.aensure_db_setup()
    response = await self.async_collection.find_one(
        filter={"_id": self._make_id(prompt, llm_string)},
        projection={"body_blob": 1},
    )
    document = response["data"]["document"]
    if document is None:
        return None
    return _loads_generations(document["body_blob"])
def update(self, prompt: str, llm_string: str, return_val: RETURN_VAL_TYPE) -> None:
"""Update cache based on prompt and llm_string."""
self.astra_env.ensure_db_setup()
doc_id = self._make_id(prompt, llm_string)
blob = _dumps_generations(return_val)
self.collection.upsert(
@@ -1441,6 +1468,20 @@ class AstraDBCache(BaseCache):
},
)
async def aupdate(
    self, prompt: str, llm_string: str, return_val: RETURN_VAL_TYPE
) -> None:
    """Asynchronously store ``return_val`` under the prompt/llm_string key.

    Awaits ``astra_env.aensure_db_setup()``, serializes ``return_val`` with
    ``_dumps_generations`` and upserts a document made of the ``_make_id``
    identifier and the serialized blob into the async collection.
    """
    await self.astra_env.aensure_db_setup()
    document = {
        "_id": self._make_id(prompt, llm_string),
        "body_blob": _dumps_generations(return_val),
    }
    await self.async_collection.upsert(document)
def delete_through_llm(
self, prompt: str, llm: LLM, stop: Optional[List[str]] = None
) -> None:
@@ -1454,14 +1495,42 @@ class AstraDBCache(BaseCache):
)[1]
return self.delete(prompt, llm_string=llm_string)
async def adelete_through_llm(
    self, prompt: str, llm: LLM, stop: Optional[List[str]] = None
) -> None:
    """
    Convenience wrapper around `adelete` that takes the LLM itself.

    The llm_string is taken from element 1 of what `aget_prompts` returns
    for the LLM's parameters. If the llm(prompt) calls were made with a
    `stop` param, pass the same value here.
    """
    llm_params = {**llm.dict(), "stop": stop}
    prompts_info = await aget_prompts(llm_params, [])
    llm_string = prompts_info[1]
    return await self.adelete(prompt, llm_string=llm_string)
def delete(self, prompt: str, llm_string: str) -> None:
    """Evict the entry for this prompt/llm_string pair, if there is one.

    Calls ``astra_env.ensure_db_setup()`` first, then issues ``delete_one``
    on the collection with the id built by ``_make_id``.
    """
    self.astra_env.ensure_db_setup()
    self.collection.delete_one(self._make_id(prompt, llm_string))
async def adelete(self, prompt: str, llm_string: str) -> None:
    """Asynchronously evict the entry for this prompt/llm_string pair, if any.

    Awaits ``astra_env.aensure_db_setup()`` first, then awaits ``delete_one``
    on the async collection with the id built by ``_make_id``.
    """
    await self.astra_env.aensure_db_setup()
    target_id = self._make_id(prompt, llm_string)
    collection = self.async_collection
    await collection.delete_one(target_id)
def clear(self, **kwargs: Any) -> None:
"""Clear cache. This is for all LLMs at once.

Extra keyword arguments are accepted but are not read by the statements below.
"""
# NOTE(review): the next call and the two statements after it look like the
# before/after sides of a single change (truncate via `astra_db` vs. clear
# via `collection`) rendered without diff markers -- confirm which one is
# current before relying on both being executed.
self.astra_db.truncate_collection(self.collection_name)
# Presumably makes sure the collection setup has completed before it is
# used -- confirm against the `astra_env` implementation.
self.astra_env.ensure_db_setup()
self.collection.clear()
async def aclear(self, **kwargs: Any) -> None:
    """Asynchronously clear the cache. This affects all LLMs at once.

    Awaits ``astra_env.aensure_db_setup()`` and then awaits ``clear()`` on
    the async collection. Extra keyword arguments are accepted but unused.
    """
    await self.astra_env.aensure_db_setup()
    collection = self.async_collection
    await collection.clear()
ASTRA_DB_SEMANTIC_CACHE_DEFAULT_THRESHOLD = 0.85
@@ -1469,6 +1538,42 @@ ASTRA_DB_CACHE_DEFAULT_COLLECTION_NAME = "langchain_astradb_semantic_cache"
ASTRA_DB_SEMANTIC_CACHE_EMBEDDING_CACHE_SIZE = 16
_unset = ["unset"]


class _CachedAwaitable:
    """Cache the outcome of an awaitable so it can be awaited multiple times.

    The wrapped awaitable is turned into a single shared future the first
    time this object is awaited; every awaiter (including several tasks
    awaiting at the same time) waits on that one future. This avoids driving
    the same coroutine from two places, which raises
    ``RuntimeError: coroutine is being awaited already``, and it makes a
    failed awaitable re-raise its original exception on later awaits instead
    of ``RuntimeError: cannot reuse already awaited coroutine``.

    Attributes:
        awaitable: the wrapped awaitable, as passed to the constructor.
        result: the cached result, or the module-level ``_unset`` sentinel
            while no result has been obtained yet.
    """

    def __init__(self, awaitable: Awaitable[Any]):
        self.awaitable = awaitable
        self.result = _unset
        # Created lazily in __await__, because building a future/task needs
        # an event loop and instances may be constructed outside of one.
        self._future: Optional[Any] = None

    def __await__(self) -> Generator:
        if self.result is _unset:
            if self._future is None:
                # Function-scope import: keeps this block self-contained.
                import asyncio

                self._future = asyncio.ensure_future(self.awaitable)
            # All awaiters wait on the same future; the wrapped awaitable
            # itself is only ever driven by that future.
            self.result = yield from self._future.__await__()
        return self.result
def _reawaitable(func: Callable) -> Callable:
    """Decorate an async function so each call's result can be re-awaited.

    The returned wrapper calls ``func`` with the given arguments and hands
    the resulting awaitable to ``_CachedAwaitable``. ``functools.wraps``
    copies ``func``'s metadata onto the wrapper.
    """

    @wraps(func)
    def make_cached(*call_args: Any, **call_kwargs: Any) -> _CachedAwaitable:
        pending = func(*call_args, **call_kwargs)
        return _CachedAwaitable(pending)

    return make_cached
def _async_lru_cache(maxsize: int = 128, typed: bool = False) -> Callable:
    """Least-recently-used cache decorator for async functions.

    Counterpart of ``functools.lru_cache`` for coroutine functions: the
    decorated function is first passed through ``_reawaitable`` and the
    result is then wrapped with ``lru_cache(maxsize, typed)``.
    """

    def decorate(user_function: Callable) -> Callable:
        memoize = lru_cache(maxsize, typed)
        reusable = _reawaitable(user_function)
        return memoize(reusable)

    return decorate
class AstraDBSemanticCache(BaseCache):
"""
Cache that uses Astra DB as a vector-store backend for semantic
@@ -1479,7 +1584,7 @@ class AstraDBSemanticCache(BaseCache):
in the document metadata.
You can choose the preferred similarity (or use the API default) --
remember the threshold might require metric-dependend tuning.
remember the threshold might require metric-dependent tuning.
"""
def __init__(
@@ -1489,7 +1594,10 @@ class AstraDBSemanticCache(BaseCache):
token: Optional[str] = None,
api_endpoint: Optional[str] = None,
astra_db_client: Optional[AstraDB] = None,
async_astra_db_client: Optional[AsyncAstraDB] = None,
namespace: Optional[str] = None,
setup_mode: SetupMode = SetupMode.SYNC,
pre_delete_collection: bool = False,
embedding: Embeddings,
metric: Optional[str] = None,
similarity_threshold: float = ASTRA_DB_SEMANTIC_CACHE_DEFAULT_THRESHOLD,
@@ -1502,10 +1610,17 @@ class AstraDBSemanticCache(BaseCache):
token (Optional[str]): API token for Astra DB usage.
api_endpoint (Optional[str]): full URL to the API endpoint,
such as "https://<DB-ID>-us-east1.apps.astra.datastax.com".
astra_db_client (Optional[Any]): *alternative to token+api_endpoint*,
astra_db_client (Optional[AstraDB]): *alternative to token+api_endpoint*,
you can pass an already-created 'astrapy.db.AstraDB' instance.
async_astra_db_client (Optional[AsyncAstraDB]):
*alternative to token+api_endpoint*,
you can pass an already-created 'astrapy.db.AsyncAstraDB' instance.
namespace (Optional[str]): namespace (aka keyspace) where the
collection is created. Defaults to the database's "default namespace".
setup_mode (SetupMode): mode used to create the collection in the DB
(SYNC, ASYNC or OFF). Defaults to SYNC.
pre_delete_collection (bool): whether to delete and re-create the
collection. Defaults to False.
embedding (Embedding): Embedding provider for semantic
encoding and search.
metric: the function to use for evaluating similarity of text embeddings.
@@ -1516,17 +1631,10 @@ class AstraDBSemanticCache(BaseCache):
The default score threshold is tuned to the default metric.
Tune it carefully yourself if switching to another distance metric.
"""
astra_env = AstraDBEnvironment(
token=token,
api_endpoint=api_endpoint,
astra_db_client=astra_db_client,
namespace=namespace,
)
self.astra_db = astra_env.astra_db
self.embedding = embedding
self.metric = metric
self.similarity_threshold = similarity_threshold
self.collection_name = collection_name
# The contract for this class has separate lookup and update:
# in order to spare some embedding calculations we cache them between
@@ -1538,25 +1646,47 @@ class AstraDBSemanticCache(BaseCache):
return self.embedding.embed_query(text=text)
self._get_embedding = _cache_embedding
self.embedding_dimension = self._get_embedding_dimension()
self.collection_name = collection_name
@_async_lru_cache(maxsize=ASTRA_DB_SEMANTIC_CACHE_EMBEDDING_CACHE_SIZE)
async def _acache_embedding(text: str) -> List[float]:
return await self.embedding.aembed_query(text=text)
self.collection = self.astra_db.create_collection(
collection_name=self.collection_name,
dimension=self.embedding_dimension,
metric=self.metric,
self._aget_embedding = _acache_embedding
embedding_dimension: Union[int, Awaitable[int], None] = None
if setup_mode == SetupMode.ASYNC:
embedding_dimension = self._aget_embedding_dimension()
elif setup_mode == SetupMode.SYNC:
embedding_dimension = self._get_embedding_dimension()
self.astra_env = _AstraDBCollectionEnvironment(
collection_name=collection_name,
token=token,
api_endpoint=api_endpoint,
astra_db_client=astra_db_client,
async_astra_db_client=async_astra_db_client,
namespace=namespace,
setup_mode=setup_mode,
pre_delete_collection=pre_delete_collection,
embedding_dimension=embedding_dimension,
metric=metric,
)
self.collection = self.astra_env.collection
self.async_collection = self.astra_env.async_collection
def _get_embedding_dimension(self) -> int:
    """Return the length of the embedding computed for a fixed sample sentence."""
    sample_vector = self._get_embedding(text="This is a sample sentence.")
    return len(sample_vector)
async def _aget_embedding_dimension(self) -> int:
    """Asynchronously return the embedding length for a fixed sample sentence."""
    sample_vector = await self._aget_embedding(text="This is a sample sentence.")
    return len(sample_vector)
@staticmethod
def _make_id(prompt: str, llm_string: str) -> str:
    """Compose a document id from the hashes of the prompt and the llm_string.

    The two hashes are separated by a ``#`` character.
    """
    hashed_prompt = _hash(prompt)
    hashed_llm_string = _hash(llm_string)
    return f"{hashed_prompt}#{hashed_llm_string}"
def update(self, prompt: str, llm_string: str, return_val: RETURN_VAL_TYPE) -> None:
"""Update cache based on prompt and llm_string."""
self.astra_env.ensure_db_setup()
doc_id = self._make_id(prompt, llm_string)
llm_string_hash = _hash(llm_string)
embedding_vector = self._get_embedding(text=prompt)
@@ -1571,6 +1701,25 @@ class AstraDBSemanticCache(BaseCache):
}
)
async def aupdate(
    self, prompt: str, llm_string: str, return_val: RETURN_VAL_TYPE
) -> None:
    """Asynchronously store ``return_val`` for this prompt/llm_string pair.

    Awaits ``astra_env.aensure_db_setup()``, computes the prompt embedding
    with ``_aget_embedding`` and upserts one document into the async
    collection. The document carries the ``_make_id`` identifier, the
    serialized generations, the hash of ``llm_string`` and the embedding
    vector.
    """
    await self.astra_env.aensure_db_setup()
    record_id = self._make_id(prompt, llm_string)
    hashed_llm_string = _hash(llm_string)
    prompt_vector = await self._aget_embedding(text=prompt)
    serialized = _dumps_generations(return_val)
    record = {
        "_id": record_id,
        "body_blob": serialized,
        "llm_string_hash": hashed_llm_string,
        "$vector": prompt_vector,
    }
    await self.async_collection.upsert(record)
def lookup(self, prompt: str, llm_string: str) -> Optional[RETURN_VAL_TYPE]:
"""Look up based on prompt and llm_string."""
hit_with_id = self.lookup_with_id(prompt, llm_string)
@@ -1579,6 +1728,14 @@ class AstraDBSemanticCache(BaseCache):
else:
return None
async def alookup(self, prompt: str, llm_string: str) -> Optional[RETURN_VAL_TYPE]:
"""Look up based on prompt and llm_string."""
hit_with_id = await self.alookup_with_id(prompt, llm_string)
if hit_with_id is not None:
return hit_with_id[1]
else:
return None
def lookup_with_id(
self, prompt: str, llm_string: str
) -> Optional[Tuple[str, RETURN_VAL_TYPE]]:
@@ -1586,6 +1743,7 @@ class AstraDBSemanticCache(BaseCache):
Look up based on prompt and llm_string.
If there are hits, return (document_id, cached_entry) for the top hit
"""
self.astra_env.ensure_db_setup()
prompt_embedding: List[float] = self._get_embedding(text=prompt)
llm_string_hash = _hash(llm_string)
@@ -1604,7 +1762,37 @@ class AstraDBSemanticCache(BaseCache):
generations = _loads_generations(hit["body_blob"])
if generations is not None:
# this protects against malformed cached items:
return (hit["_id"], generations)
return hit["_id"], generations
else:
return None
async def alookup_with_id(
    self, prompt: str, llm_string: str
) -> Optional[Tuple[str, RETURN_VAL_TYPE]]:
    """
    Asynchronously look up based on prompt and llm_string.

    Awaits the DB setup, embeds the prompt and runs a vector search on the
    async collection, filtered by the hash of llm_string.
    Returns (document_id, cached_entry) for the top hit, or None when there
    is no hit, when its similarity is below `similarity_threshold`, or when
    the stored entry cannot be loaded.
    """
    await self.astra_env.aensure_db_setup()
    query_vector: List[float] = await self._aget_embedding(text=prompt)

    best_match = await self.async_collection.vector_find_one(
        vector=query_vector,
        filter={"llm_string_hash": _hash(llm_string)},
        fields=["body_blob", "_id"],
        include_similarity=True,
    )

    if best_match is None:
        return None
    if best_match["$similarity"] < self.similarity_threshold:
        return None

    cached_entry = _loads_generations(best_match["body_blob"])
    if cached_entry is None:
        # this protects against malformed cached items
        return None
    return best_match["_id"], cached_entry
@@ -1617,14 +1805,41 @@ class AstraDBSemanticCache(BaseCache):
)[1]
return self.lookup_with_id(prompt, llm_string=llm_string)
async def alookup_with_id_through_llm(
    self, prompt: str, llm: LLM, stop: Optional[List[str]] = None
) -> Optional[Tuple[str, RETURN_VAL_TYPE]]:
    """
    Convenience wrapper around `alookup_with_id` that takes the LLM itself.

    The llm_string is taken from element 1 of what `aget_prompts` returns
    for the LLM's parameters combined with the given `stop` value.
    """
    llm_params = {**llm.dict(), "stop": stop}
    prompts_info = await aget_prompts(llm_params, [])
    llm_string = prompts_info[1]
    return await self.alookup_with_id(prompt, llm_string=llm_string)
def delete_by_document_id(self, document_id: str) -> None:
    """
    Delete a single cached document by its id.

    For a "similarity search" cache, a sensible invalidation pattern is to
    look up an entry to obtain its ID and then delete with that ID; this
    method is that second step.
    """
    self.astra_env.ensure_db_setup()
    collection = self.collection
    collection.delete_one(document_id)
async def adelete_by_document_id(self, document_id: str) -> None:
    """
    Asynchronously delete a single cached document by its id.

    For a "similarity search" cache, a sensible invalidation pattern is to
    look up an entry to obtain its ID and then delete with that ID; this
    method is that second step.
    """
    await self.astra_env.aensure_db_setup()
    collection = self.async_collection
    await collection.delete_one(document_id)
def clear(self, **kwargs: Any) -> None:
"""Clear the *whole* semantic cache.

Extra keyword arguments are accepted but are not read by the statements below.
"""
# NOTE(review): the next call and the two statements after it look like the
# before/after sides of a single change (truncate via `astra_db` vs. clear
# via `collection`) rendered without diff markers -- confirm which one is
# current before relying on both being executed.
self.astra_db.truncate_collection(self.collection_name)
# Presumably makes sure the collection setup has completed before it is
# used -- confirm against the `astra_env` implementation.
self.astra_env.ensure_db_setup()
self.collection.clear()
async def aclear(self, **kwargs: Any) -> None:
    """Asynchronously clear the *whole* semantic cache.

    Awaits ``astra_env.aensure_db_setup()`` and then awaits ``clear()`` on
    the async collection. Extra keyword arguments are accepted but unused.
    """
    await self.astra_env.aensure_db_setup()
    collection = self.async_collection
    await collection.clear()