community: Add async methods to AstraDBCache (#17415)

Adds async methods to AstraDBCache
This commit is contained in:
Christophe Bornet 2024-02-15 05:10:08 +01:00 committed by GitHub
parent e438fe6be9
commit ca2d4078f3
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
6 changed files with 465 additions and 106 deletions

View File

@ -29,12 +29,14 @@ import uuid
import warnings import warnings
from abc import ABC from abc import ABC
from datetime import timedelta from datetime import timedelta
from functools import lru_cache from functools import lru_cache, wraps
from typing import ( from typing import (
TYPE_CHECKING, TYPE_CHECKING,
Any, Any,
Awaitable,
Callable, Callable,
Dict, Dict,
Generator,
List, List,
Optional, Optional,
Sequence, Sequence,
@ -56,20 +58,23 @@ except ImportError:
from langchain_core.caches import RETURN_VAL_TYPE, BaseCache from langchain_core.caches import RETURN_VAL_TYPE, BaseCache
from langchain_core.embeddings import Embeddings from langchain_core.embeddings import Embeddings
from langchain_core.language_models.llms import LLM, get_prompts from langchain_core.language_models.llms import LLM, aget_prompts, get_prompts
from langchain_core.load.dump import dumps from langchain_core.load.dump import dumps
from langchain_core.load.load import loads from langchain_core.load.load import loads
from langchain_core.outputs import ChatGeneration, Generation from langchain_core.outputs import ChatGeneration, Generation
from langchain_core.utils import get_from_env from langchain_core.utils import get_from_env
from langchain_community.utilities.astradb import AstraDBEnvironment from langchain_community.utilities.astradb import (
SetupMode,
_AstraDBCollectionEnvironment,
)
from langchain_community.vectorstores.redis import Redis as RedisVectorstore from langchain_community.vectorstores.redis import Redis as RedisVectorstore
logger = logging.getLogger(__file__) logger = logging.getLogger(__file__)
if TYPE_CHECKING: if TYPE_CHECKING:
import momento import momento
from astrapy.db import AstraDB from astrapy.db import AstraDB, AsyncAstraDB
from cassandra.cluster import Session as CassandraSession from cassandra.cluster import Session as CassandraSession
@ -1371,6 +1376,10 @@ class AstraDBCache(BaseCache):
(needed to prevent same-prompt-different-model collisions) (needed to prevent same-prompt-different-model collisions)
""" """
@staticmethod
def _make_id(prompt: str, llm_string: str) -> str:
return f"{_hash(prompt)}#{_hash(llm_string)}"
def __init__( def __init__(
self, self,
*, *,
@ -1378,7 +1387,10 @@ class AstraDBCache(BaseCache):
token: Optional[str] = None, token: Optional[str] = None,
api_endpoint: Optional[str] = None, api_endpoint: Optional[str] = None,
astra_db_client: Optional[AstraDB] = None, astra_db_client: Optional[AstraDB] = None,
async_astra_db_client: Optional[AsyncAstraDB] = None,
namespace: Optional[str] = None, namespace: Optional[str] = None,
pre_delete_collection: bool = False,
setup_mode: SetupMode = SetupMode.SYNC,
): ):
""" """
Create an AstraDB cache using a collection for storage. Create an AstraDB cache using a collection for storage.
@ -1388,29 +1400,35 @@ class AstraDBCache(BaseCache):
token (Optional[str]): API token for Astra DB usage. token (Optional[str]): API token for Astra DB usage.
api_endpoint (Optional[str]): full URL to the API endpoint, api_endpoint (Optional[str]): full URL to the API endpoint,
such as "https://<DB-ID>-us-east1.apps.astra.datastax.com". such as "https://<DB-ID>-us-east1.apps.astra.datastax.com".
astra_db_client (Optional[Any]): *alternative to token+api_endpoint*, astra_db_client (Optional[AstraDB]):
*alternative to token+api_endpoint*,
you can pass an already-created 'astrapy.db.AstraDB' instance. you can pass an already-created 'astrapy.db.AstraDB' instance.
async_astra_db_client (Optional[AsyncAstraDB]):
*alternative to token+api_endpoint*,
you can pass an already-created 'astrapy.db.AsyncAstraDB' instance.
namespace (Optional[str]): namespace (aka keyspace) where the namespace (Optional[str]): namespace (aka keyspace) where the
collection is created. Defaults to the database's "default namespace". collection is created. Defaults to the database's "default namespace".
pre_delete_collection (bool): whether to delete and re-create the
collection. Defaults to False.
async_setup (bool): whether to create the collection asynchronously.
Enable only if there is a running asyncio event loop. Defaults to False.
""" """
astra_env = AstraDBEnvironment( self.astra_env = _AstraDBCollectionEnvironment(
collection_name=collection_name,
token=token, token=token,
api_endpoint=api_endpoint, api_endpoint=api_endpoint,
astra_db_client=astra_db_client, astra_db_client=astra_db_client,
async_astra_db_client=async_astra_db_client,
namespace=namespace, namespace=namespace,
setup_mode=setup_mode,
pre_delete_collection=pre_delete_collection,
) )
self.astra_db = astra_env.astra_db self.collection = self.astra_env.collection
self.collection = self.astra_db.create_collection( self.async_collection = self.astra_env.async_collection
collection_name=collection_name,
)
self.collection_name = collection_name
@staticmethod
def _make_id(prompt: str, llm_string: str) -> str:
return f"{_hash(prompt)}#{_hash(llm_string)}"
def lookup(self, prompt: str, llm_string: str) -> Optional[RETURN_VAL_TYPE]: def lookup(self, prompt: str, llm_string: str) -> Optional[RETURN_VAL_TYPE]:
"""Look up based on prompt and llm_string.""" """Look up based on prompt and llm_string."""
self.astra_env.ensure_db_setup()
doc_id = self._make_id(prompt, llm_string) doc_id = self._make_id(prompt, llm_string)
item = self.collection.find_one( item = self.collection.find_one(
filter={ filter={
@ -1420,18 +1438,27 @@ class AstraDBCache(BaseCache):
"body_blob": 1, "body_blob": 1,
}, },
)["data"]["document"] )["data"]["document"]
if item is not None: return _loads_generations(item["body_blob"]) if item is not None else None
generations = _loads_generations(item["body_blob"])
# this protects against malformed cached items: async def alookup(self, prompt: str, llm_string: str) -> Optional[RETURN_VAL_TYPE]:
if generations is not None: """Look up based on prompt and llm_string."""
return generations await self.astra_env.aensure_db_setup()
else: doc_id = self._make_id(prompt, llm_string)
return None item = (
else: await self.async_collection.find_one(
return None filter={
"_id": doc_id,
},
projection={
"body_blob": 1,
},
)
)["data"]["document"]
return _loads_generations(item["body_blob"]) if item is not None else None
def update(self, prompt: str, llm_string: str, return_val: RETURN_VAL_TYPE) -> None: def update(self, prompt: str, llm_string: str, return_val: RETURN_VAL_TYPE) -> None:
"""Update cache based on prompt and llm_string.""" """Update cache based on prompt and llm_string."""
self.astra_env.ensure_db_setup()
doc_id = self._make_id(prompt, llm_string) doc_id = self._make_id(prompt, llm_string)
blob = _dumps_generations(return_val) blob = _dumps_generations(return_val)
self.collection.upsert( self.collection.upsert(
@ -1441,6 +1468,20 @@ class AstraDBCache(BaseCache):
}, },
) )
async def aupdate(
self, prompt: str, llm_string: str, return_val: RETURN_VAL_TYPE
) -> None:
"""Update cache based on prompt and llm_string."""
await self.astra_env.aensure_db_setup()
doc_id = self._make_id(prompt, llm_string)
blob = _dumps_generations(return_val)
await self.async_collection.upsert(
{
"_id": doc_id,
"body_blob": blob,
},
)
def delete_through_llm( def delete_through_llm(
self, prompt: str, llm: LLM, stop: Optional[List[str]] = None self, prompt: str, llm: LLM, stop: Optional[List[str]] = None
) -> None: ) -> None:
@ -1454,14 +1495,42 @@ class AstraDBCache(BaseCache):
)[1] )[1]
return self.delete(prompt, llm_string=llm_string) return self.delete(prompt, llm_string=llm_string)
async def adelete_through_llm(
self, prompt: str, llm: LLM, stop: Optional[List[str]] = None
) -> None:
"""
A wrapper around `adelete` with the LLM being passed.
In case the llm(prompt) calls have a `stop` param, you should pass it here
"""
llm_string = (
await aget_prompts(
{**llm.dict(), **{"stop": stop}},
[],
)
)[1]
return await self.adelete(prompt, llm_string=llm_string)
def delete(self, prompt: str, llm_string: str) -> None: def delete(self, prompt: str, llm_string: str) -> None:
"""Evict from cache if there's an entry.""" """Evict from cache if there's an entry."""
self.astra_env.ensure_db_setup()
doc_id = self._make_id(prompt, llm_string) doc_id = self._make_id(prompt, llm_string)
self.collection.delete_one(doc_id) self.collection.delete_one(doc_id)
async def adelete(self, prompt: str, llm_string: str) -> None:
"""Evict from cache if there's an entry."""
await self.astra_env.aensure_db_setup()
doc_id = self._make_id(prompt, llm_string)
await self.async_collection.delete_one(doc_id)
def clear(self, **kwargs: Any) -> None: def clear(self, **kwargs: Any) -> None:
"""Clear cache. This is for all LLMs at once.""" """Clear cache. This is for all LLMs at once."""
self.astra_db.truncate_collection(self.collection_name) self.astra_env.ensure_db_setup()
self.collection.clear()
async def aclear(self, **kwargs: Any) -> None:
"""Clear cache. This is for all LLMs at once."""
await self.astra_env.aensure_db_setup()
await self.async_collection.clear()
ASTRA_DB_SEMANTIC_CACHE_DEFAULT_THRESHOLD = 0.85 ASTRA_DB_SEMANTIC_CACHE_DEFAULT_THRESHOLD = 0.85
@ -1469,6 +1538,42 @@ ASTRA_DB_CACHE_DEFAULT_COLLECTION_NAME = "langchain_astradb_semantic_cache"
ASTRA_DB_SEMANTIC_CACHE_EMBEDDING_CACHE_SIZE = 16 ASTRA_DB_SEMANTIC_CACHE_EMBEDDING_CACHE_SIZE = 16
_unset = ["unset"]
class _CachedAwaitable:
"""Caches the result of an awaitable so it can be awaited multiple times"""
def __init__(self, awaitable: Awaitable[Any]):
self.awaitable = awaitable
self.result = _unset
def __await__(self) -> Generator:
if self.result is _unset:
self.result = yield from self.awaitable.__await__()
return self.result
def _reawaitable(func: Callable) -> Callable:
"""Makes an async function result awaitable multiple times"""
@wraps(func)
def wrapper(*args: Any, **kwargs: Any) -> _CachedAwaitable:
return _CachedAwaitable(func(*args, **kwargs))
return wrapper
def _async_lru_cache(maxsize: int = 128, typed: bool = False) -> Callable:
"""Least-recently-used async cache decorator.
Equivalent to functools.lru_cache for async functions"""
def decorating_function(user_function: Callable) -> Callable:
return lru_cache(maxsize, typed)(_reawaitable(user_function))
return decorating_function
class AstraDBSemanticCache(BaseCache): class AstraDBSemanticCache(BaseCache):
""" """
Cache that uses Astra DB as a vector-store backend for semantic Cache that uses Astra DB as a vector-store backend for semantic
@ -1479,7 +1584,7 @@ class AstraDBSemanticCache(BaseCache):
in the document metadata. in the document metadata.
You can choose the preferred similarity (or use the API default) -- You can choose the preferred similarity (or use the API default) --
remember the threshold might require metric-dependend tuning. remember the threshold might require metric-dependent tuning.
""" """
def __init__( def __init__(
@ -1489,7 +1594,10 @@ class AstraDBSemanticCache(BaseCache):
token: Optional[str] = None, token: Optional[str] = None,
api_endpoint: Optional[str] = None, api_endpoint: Optional[str] = None,
astra_db_client: Optional[AstraDB] = None, astra_db_client: Optional[AstraDB] = None,
async_astra_db_client: Optional[AsyncAstraDB] = None,
namespace: Optional[str] = None, namespace: Optional[str] = None,
setup_mode: SetupMode = SetupMode.SYNC,
pre_delete_collection: bool = False,
embedding: Embeddings, embedding: Embeddings,
metric: Optional[str] = None, metric: Optional[str] = None,
similarity_threshold: float = ASTRA_DB_SEMANTIC_CACHE_DEFAULT_THRESHOLD, similarity_threshold: float = ASTRA_DB_SEMANTIC_CACHE_DEFAULT_THRESHOLD,
@ -1502,10 +1610,17 @@ class AstraDBSemanticCache(BaseCache):
token (Optional[str]): API token for Astra DB usage. token (Optional[str]): API token for Astra DB usage.
api_endpoint (Optional[str]): full URL to the API endpoint, api_endpoint (Optional[str]): full URL to the API endpoint,
such as "https://<DB-ID>-us-east1.apps.astra.datastax.com". such as "https://<DB-ID>-us-east1.apps.astra.datastax.com".
astra_db_client (Optional[Any]): *alternative to token+api_endpoint*, astra_db_client (Optional[AstraDB]): *alternative to token+api_endpoint*,
you can pass an already-created 'astrapy.db.AstraDB' instance. you can pass an already-created 'astrapy.db.AstraDB' instance.
async_astra_db_client (Optional[AsyncAstraDB]):
*alternative to token+api_endpoint*,
you can pass an already-created 'astrapy.db.AsyncAstraDB' instance.
namespace (Optional[str]): namespace (aka keyspace) where the namespace (Optional[str]): namespace (aka keyspace) where the
collection is created. Defaults to the database's "default namespace". collection is created. Defaults to the database's "default namespace".
setup_mode (SetupMode): mode used to create the collection in the DB
(SYNC, ASYNC or OFF). Defaults to SYNC.
pre_delete_collection (bool): whether to delete and re-create the
collection. Defaults to False.
embedding (Embedding): Embedding provider for semantic embedding (Embedding): Embedding provider for semantic
encoding and search. encoding and search.
metric: the function to use for evaluating similarity of text embeddings. metric: the function to use for evaluating similarity of text embeddings.
@ -1516,17 +1631,10 @@ class AstraDBSemanticCache(BaseCache):
The default score threshold is tuned to the default metric. The default score threshold is tuned to the default metric.
Tune it carefully yourself if switching to another distance metric. Tune it carefully yourself if switching to another distance metric.
""" """
astra_env = AstraDBEnvironment(
token=token,
api_endpoint=api_endpoint,
astra_db_client=astra_db_client,
namespace=namespace,
)
self.astra_db = astra_env.astra_db
self.embedding = embedding self.embedding = embedding
self.metric = metric self.metric = metric
self.similarity_threshold = similarity_threshold self.similarity_threshold = similarity_threshold
self.collection_name = collection_name
# The contract for this class has separate lookup and update: # The contract for this class has separate lookup and update:
# in order to spare some embedding calculations we cache them between # in order to spare some embedding calculations we cache them between
@ -1538,25 +1646,47 @@ class AstraDBSemanticCache(BaseCache):
return self.embedding.embed_query(text=text) return self.embedding.embed_query(text=text)
self._get_embedding = _cache_embedding self._get_embedding = _cache_embedding
self.embedding_dimension = self._get_embedding_dimension()
self.collection_name = collection_name @_async_lru_cache(maxsize=ASTRA_DB_SEMANTIC_CACHE_EMBEDDING_CACHE_SIZE)
async def _acache_embedding(text: str) -> List[float]:
return await self.embedding.aembed_query(text=text)
self.collection = self.astra_db.create_collection( self._aget_embedding = _acache_embedding
collection_name=self.collection_name,
dimension=self.embedding_dimension, embedding_dimension: Union[int, Awaitable[int], None] = None
metric=self.metric, if setup_mode == SetupMode.ASYNC:
embedding_dimension = self._aget_embedding_dimension()
elif setup_mode == SetupMode.SYNC:
embedding_dimension = self._get_embedding_dimension()
self.astra_env = _AstraDBCollectionEnvironment(
collection_name=collection_name,
token=token,
api_endpoint=api_endpoint,
astra_db_client=astra_db_client,
async_astra_db_client=async_astra_db_client,
namespace=namespace,
setup_mode=setup_mode,
pre_delete_collection=pre_delete_collection,
embedding_dimension=embedding_dimension,
metric=metric,
) )
self.collection = self.astra_env.collection
self.async_collection = self.astra_env.async_collection
def _get_embedding_dimension(self) -> int: def _get_embedding_dimension(self) -> int:
return len(self._get_embedding(text="This is a sample sentence.")) return len(self._get_embedding(text="This is a sample sentence."))
async def _aget_embedding_dimension(self) -> int:
return len(await self._aget_embedding(text="This is a sample sentence."))
@staticmethod @staticmethod
def _make_id(prompt: str, llm_string: str) -> str: def _make_id(prompt: str, llm_string: str) -> str:
return f"{_hash(prompt)}#{_hash(llm_string)}" return f"{_hash(prompt)}#{_hash(llm_string)}"
def update(self, prompt: str, llm_string: str, return_val: RETURN_VAL_TYPE) -> None: def update(self, prompt: str, llm_string: str, return_val: RETURN_VAL_TYPE) -> None:
"""Update cache based on prompt and llm_string.""" """Update cache based on prompt and llm_string."""
self.astra_env.ensure_db_setup()
doc_id = self._make_id(prompt, llm_string) doc_id = self._make_id(prompt, llm_string)
llm_string_hash = _hash(llm_string) llm_string_hash = _hash(llm_string)
embedding_vector = self._get_embedding(text=prompt) embedding_vector = self._get_embedding(text=prompt)
@ -1571,6 +1701,25 @@ class AstraDBSemanticCache(BaseCache):
} }
) )
async def aupdate(
self, prompt: str, llm_string: str, return_val: RETURN_VAL_TYPE
) -> None:
"""Update cache based on prompt and llm_string."""
await self.astra_env.aensure_db_setup()
doc_id = self._make_id(prompt, llm_string)
llm_string_hash = _hash(llm_string)
embedding_vector = await self._aget_embedding(text=prompt)
body = _dumps_generations(return_val)
#
await self.async_collection.upsert(
{
"_id": doc_id,
"body_blob": body,
"llm_string_hash": llm_string_hash,
"$vector": embedding_vector,
}
)
def lookup(self, prompt: str, llm_string: str) -> Optional[RETURN_VAL_TYPE]: def lookup(self, prompt: str, llm_string: str) -> Optional[RETURN_VAL_TYPE]:
"""Look up based on prompt and llm_string.""" """Look up based on prompt and llm_string."""
hit_with_id = self.lookup_with_id(prompt, llm_string) hit_with_id = self.lookup_with_id(prompt, llm_string)
@ -1579,6 +1728,14 @@ class AstraDBSemanticCache(BaseCache):
else: else:
return None return None
async def alookup(self, prompt: str, llm_string: str) -> Optional[RETURN_VAL_TYPE]:
"""Look up based on prompt and llm_string."""
hit_with_id = await self.alookup_with_id(prompt, llm_string)
if hit_with_id is not None:
return hit_with_id[1]
else:
return None
def lookup_with_id( def lookup_with_id(
self, prompt: str, llm_string: str self, prompt: str, llm_string: str
) -> Optional[Tuple[str, RETURN_VAL_TYPE]]: ) -> Optional[Tuple[str, RETURN_VAL_TYPE]]:
@ -1586,6 +1743,7 @@ class AstraDBSemanticCache(BaseCache):
Look up based on prompt and llm_string. Look up based on prompt and llm_string.
If there are hits, return (document_id, cached_entry) for the top hit If there are hits, return (document_id, cached_entry) for the top hit
""" """
self.astra_env.ensure_db_setup()
prompt_embedding: List[float] = self._get_embedding(text=prompt) prompt_embedding: List[float] = self._get_embedding(text=prompt)
llm_string_hash = _hash(llm_string) llm_string_hash = _hash(llm_string)
@ -1604,7 +1762,37 @@ class AstraDBSemanticCache(BaseCache):
generations = _loads_generations(hit["body_blob"]) generations = _loads_generations(hit["body_blob"])
if generations is not None: if generations is not None:
# this protects against malformed cached items: # this protects against malformed cached items:
return (hit["_id"], generations) return hit["_id"], generations
else:
return None
async def alookup_with_id(
self, prompt: str, llm_string: str
) -> Optional[Tuple[str, RETURN_VAL_TYPE]]:
"""
Look up based on prompt and llm_string.
If there are hits, return (document_id, cached_entry) for the top hit
"""
await self.astra_env.aensure_db_setup()
prompt_embedding: List[float] = await self._aget_embedding(text=prompt)
llm_string_hash = _hash(llm_string)
hit = await self.async_collection.vector_find_one(
vector=prompt_embedding,
filter={
"llm_string_hash": llm_string_hash,
},
fields=["body_blob", "_id"],
include_similarity=True,
)
if hit is None or hit["$similarity"] < self.similarity_threshold:
return None
else:
generations = _loads_generations(hit["body_blob"])
if generations is not None:
# this protects against malformed cached items:
return hit["_id"], generations
else: else:
return None return None
@ -1617,14 +1805,41 @@ class AstraDBSemanticCache(BaseCache):
)[1] )[1]
return self.lookup_with_id(prompt, llm_string=llm_string) return self.lookup_with_id(prompt, llm_string=llm_string)
async def alookup_with_id_through_llm(
self, prompt: str, llm: LLM, stop: Optional[List[str]] = None
) -> Optional[Tuple[str, RETURN_VAL_TYPE]]:
llm_string = (
await aget_prompts(
{**llm.dict(), **{"stop": stop}},
[],
)
)[1]
return await self.alookup_with_id(prompt, llm_string=llm_string)
def delete_by_document_id(self, document_id: str) -> None: def delete_by_document_id(self, document_id: str) -> None:
""" """
Given this is a "similarity search" cache, an invalidation pattern Given this is a "similarity search" cache, an invalidation pattern
that makes sense is first a lookup to get an ID, and then deleting that makes sense is first a lookup to get an ID, and then deleting
with that ID. This is for the second step. with that ID. This is for the second step.
""" """
self.astra_env.ensure_db_setup()
self.collection.delete_one(document_id) self.collection.delete_one(document_id)
async def adelete_by_document_id(self, document_id: str) -> None:
"""
Given this is a "similarity search" cache, an invalidation pattern
that makes sense is first a lookup to get an ID, and then deleting
with that ID. This is for the second step.
"""
await self.astra_env.aensure_db_setup()
await self.async_collection.delete_one(document_id)
def clear(self, **kwargs: Any) -> None: def clear(self, **kwargs: Any) -> None:
"""Clear the *whole* semantic cache.""" """Clear the *whole* semantic cache."""
self.astra_db.truncate_collection(self.collection_name) self.astra_env.ensure_db_setup()
self.collection.clear()
async def aclear(self, **kwargs: Any) -> None:
"""Clear the *whole* semantic cache."""
await self.astra_env.aensure_db_setup()
await self.async_collection.clear()

View File

@ -5,7 +5,7 @@ import json
import time import time
from typing import TYPE_CHECKING, List, Optional from typing import TYPE_CHECKING, List, Optional
from langchain_community.utilities.astradb import AstraDBEnvironment from langchain_community.utilities.astradb import _AstraDBEnvironment
if TYPE_CHECKING: if TYPE_CHECKING:
from astrapy.db import AstraDB from astrapy.db import AstraDB
@ -47,7 +47,7 @@ class AstraDBChatMessageHistory(BaseChatMessageHistory):
namespace: Optional[str] = None, namespace: Optional[str] = None,
) -> None: ) -> None:
"""Create an Astra DB chat message history.""" """Create an Astra DB chat message history."""
astra_env = AstraDBEnvironment( astra_env = _AstraDBEnvironment(
token=token, token=token,
api_endpoint=api_endpoint, api_endpoint=api_endpoint,
astra_db_client=astra_db_client, astra_db_client=astra_db_client,

View File

@ -19,7 +19,7 @@ from langchain_core.documents import Document
from langchain_core.runnables import run_in_executor from langchain_core.runnables import run_in_executor
from langchain_community.document_loaders.base import BaseLoader from langchain_community.document_loaders.base import BaseLoader
from langchain_community.utilities.astradb import AstraDBEnvironment from langchain_community.utilities.astradb import _AstraDBEnvironment
if TYPE_CHECKING: if TYPE_CHECKING:
from astrapy.db import AstraDB, AsyncAstraDB from astrapy.db import AstraDB, AsyncAstraDB
@ -44,7 +44,7 @@ class AstraDBLoader(BaseLoader):
nb_prefetched: int = 1000, nb_prefetched: int = 1000,
extraction_function: Callable[[Dict], str] = json.dumps, extraction_function: Callable[[Dict], str] = json.dumps,
) -> None: ) -> None:
astra_env = AstraDBEnvironment( astra_env = _AstraDBEnvironment(
token=token, token=token,
api_endpoint=api_endpoint, api_endpoint=api_endpoint,
astra_db_client=astra_db_client, astra_db_client=astra_db_client,

View File

@ -16,7 +16,7 @@ from typing import (
from langchain_core.stores import BaseStore, ByteStore from langchain_core.stores import BaseStore, ByteStore
from langchain_community.utilities.astradb import AstraDBEnvironment from langchain_community.utilities.astradb import _AstraDBEnvironment
if TYPE_CHECKING: if TYPE_CHECKING:
from astrapy.db import AstraDB from astrapy.db import AstraDB
@ -35,7 +35,7 @@ class AstraDBBaseStore(Generic[V], BaseStore[str, V], ABC):
astra_db_client: Optional[AstraDB] = None, astra_db_client: Optional[AstraDB] = None,
namespace: Optional[str] = None, namespace: Optional[str] = None,
) -> None: ) -> None:
astra_env = AstraDBEnvironment( astra_env = _AstraDBEnvironment(
token=token, token=token,
api_endpoint=api_endpoint, api_endpoint=api_endpoint,
astra_db_client=astra_db_client, astra_db_client=astra_db_client,

View File

@ -1,6 +1,10 @@
from __future__ import annotations from __future__ import annotations
from typing import TYPE_CHECKING, Optional import asyncio
import inspect
from asyncio import InvalidStateError, Task
from enum import Enum
from typing import TYPE_CHECKING, Awaitable, Optional, Union
if TYPE_CHECKING: if TYPE_CHECKING:
from astrapy.db import ( from astrapy.db import (
@ -9,7 +13,13 @@ if TYPE_CHECKING:
) )
class AstraDBEnvironment: class SetupMode(Enum):
SYNC = 1
ASYNC = 2
OFF = 3
class _AstraDBEnvironment:
def __init__( def __init__(
self, self,
token: Optional[str] = None, token: Optional[str] = None,
@ -21,21 +31,20 @@ class AstraDBEnvironment:
self.token = token self.token = token
self.api_endpoint = api_endpoint self.api_endpoint = api_endpoint
astra_db = astra_db_client astra_db = astra_db_client
self.async_astra_db = async_astra_db_client async_astra_db = async_astra_db_client
self.namespace = namespace self.namespace = namespace
from astrapy import db
try: try:
from astrapy.db import AstraDB from astrapy.db import (
AstraDB,
AsyncAstraDB,
)
except (ImportError, ModuleNotFoundError): except (ImportError, ModuleNotFoundError):
raise ImportError( raise ImportError(
"Could not import a recent astrapy python package. " "Could not import a recent astrapy python package. "
"Please install it with `pip install --upgrade astrapy`." "Please install it with `pip install --upgrade astrapy`."
) )
supports_async = hasattr(db, "AsyncAstraDB")
# Conflicting-arg checks: # Conflicting-arg checks:
if astra_db_client is not None or async_astra_db_client is not None: if astra_db_client is not None or async_astra_db_client is not None:
if token is not None or api_endpoint is not None: if token is not None or api_endpoint is not None:
@ -46,21 +55,30 @@ class AstraDBEnvironment:
if token and api_endpoint: if token and api_endpoint:
astra_db = AstraDB( astra_db = AstraDB(
token=self.token, token=token,
api_endpoint=self.api_endpoint, api_endpoint=api_endpoint,
namespace=self.namespace, namespace=self.namespace,
) )
if supports_async: async_astra_db = AsyncAstraDB(
self.async_astra_db = db.AsyncAstraDB( token=token,
token=self.token, api_endpoint=api_endpoint,
api_endpoint=self.api_endpoint,
namespace=self.namespace, namespace=self.namespace,
) )
if astra_db: if astra_db:
self.astra_db = astra_db self.astra_db = astra_db
if async_astra_db:
self.async_astra_db = async_astra_db
else: else:
if self.async_astra_db: self.async_astra_db = AsyncAstraDB(
token=self.astra_db.token,
api_endpoint=self.astra_db.base_url,
api_path=self.astra_db.api_path,
api_version=self.astra_db.api_version,
namespace=self.astra_db.namespace,
)
elif async_astra_db:
self.async_astra_db = async_astra_db
self.astra_db = AstraDB( self.astra_db = AstraDB(
token=self.async_astra_db.token, token=self.async_astra_db.token,
api_endpoint=self.async_astra_db.base_url, api_endpoint=self.async_astra_db.base_url,
@ -74,11 +92,78 @@ class AstraDBEnvironment:
"'token' and 'api_endpoint'" "'token' and 'api_endpoint'"
) )
if not self.async_astra_db and self.astra_db and supports_async:
self.async_astra_db = db.AsyncAstraDB( class _AstraDBCollectionEnvironment(_AstraDBEnvironment):
token=self.astra_db.token, def __init__(
api_endpoint=self.astra_db.base_url, self,
api_path=self.astra_db.api_path, collection_name: str,
api_version=self.astra_db.api_version, token: Optional[str] = None,
namespace=self.astra_db.namespace, api_endpoint: Optional[str] = None,
astra_db_client: Optional[AstraDB] = None,
async_astra_db_client: Optional[AsyncAstraDB] = None,
namespace: Optional[str] = None,
setup_mode: SetupMode = SetupMode.SYNC,
pre_delete_collection: bool = False,
embedding_dimension: Union[int, Awaitable[int], None] = None,
metric: Optional[str] = None,
) -> None:
from astrapy.db import AstraDBCollection, AsyncAstraDBCollection
super().__init__(
token, api_endpoint, astra_db_client, async_astra_db_client, namespace
) )
self.collection_name = collection_name
self.collection = AstraDBCollection(
collection_name=collection_name,
astra_db=self.astra_db,
)
self.async_collection = AsyncAstraDBCollection(
collection_name=collection_name,
astra_db=self.async_astra_db,
)
self.async_setup_db_task: Optional[Task] = None
if setup_mode == SetupMode.ASYNC:
async_astra_db = self.async_astra_db
async def _setup_db() -> None:
if pre_delete_collection:
await async_astra_db.delete_collection(collection_name)
if inspect.isawaitable(embedding_dimension):
dimension = await embedding_dimension
else:
dimension = embedding_dimension
await async_astra_db.create_collection(
collection_name, dimension=dimension, metric=metric
)
self.async_setup_db_task = asyncio.create_task(_setup_db())
elif setup_mode == SetupMode.SYNC:
if pre_delete_collection:
self.astra_db.delete_collection(collection_name)
if inspect.isawaitable(embedding_dimension):
raise ValueError(
"Cannot use an awaitable embedding_dimension with async_setup "
"set to False"
)
self.astra_db.create_collection(
collection_name,
dimension=embedding_dimension, # type: ignore[arg-type]
metric=metric,
)
def ensure_db_setup(self) -> None:
if self.async_setup_db_task:
try:
self.async_setup_db_task.result()
except InvalidStateError:
raise ValueError(
"Asynchronous setup of the DB not finished. "
"NB: AstraDB components sync methods shouldn't be called from the "
"event loop. Consider using their async equivalents."
)
async def aensure_db_setup(self) -> None:
if self.async_setup_db_task:
await self.async_setup_db_task

View File

@ -12,9 +12,12 @@ Required to run this test:
""" """
import os import os
from typing import Iterator from typing import AsyncIterator, Iterator
import pytest import pytest
from langchain_community.utilities.astradb import SetupMode
from langchain_core.caches import BaseCache
from langchain_core.language_models import LLM
from langchain_core.outputs import Generation, LLMResult from langchain_core.outputs import Generation, LLMResult
from langchain.cache import AstraDBCache, AstraDBSemanticCache from langchain.cache import AstraDBCache, AstraDBSemanticCache
@ -41,7 +44,22 @@ def astradb_cache() -> Iterator[AstraDBCache]:
namespace=os.environ.get("ASTRA_DB_KEYSPACE"), namespace=os.environ.get("ASTRA_DB_KEYSPACE"),
) )
yield cache yield cache
cache.astra_db.delete_collection("lc_integration_test_cache") cache.collection.astra_db.delete_collection("lc_integration_test_cache")
@pytest.fixture
async def async_astradb_cache() -> AsyncIterator[AstraDBCache]:
cache = AstraDBCache(
collection_name="lc_integration_test_cache_async",
token=os.environ["ASTRA_DB_APPLICATION_TOKEN"],
api_endpoint=os.environ["ASTRA_DB_API_ENDPOINT"],
namespace=os.environ.get("ASTRA_DB_KEYSPACE"),
setup_mode=SetupMode.ASYNC,
)
yield cache
await cache.async_collection.astra_db.delete_collection(
"lc_integration_test_cache_async"
)
@pytest.fixture(scope="module") @pytest.fixture(scope="module")
@ -55,46 +73,87 @@ def astradb_semantic_cache() -> Iterator[AstraDBSemanticCache]:
embedding=fake_embe, embedding=fake_embe,
) )
yield sem_cache yield sem_cache
sem_cache.astra_db.delete_collection("lc_integration_test_cache") sem_cache.collection.astra_db.delete_collection("lc_integration_test_sem_cache")
@pytest.fixture
async def async_astradb_semantic_cache() -> AsyncIterator[AstraDBSemanticCache]:
fake_embe = FakeEmbeddings()
sem_cache = AstraDBSemanticCache(
collection_name="lc_integration_test_sem_cache_async",
token=os.environ["ASTRA_DB_APPLICATION_TOKEN"],
api_endpoint=os.environ["ASTRA_DB_API_ENDPOINT"],
namespace=os.environ.get("ASTRA_DB_KEYSPACE"),
embedding=fake_embe,
setup_mode=SetupMode.ASYNC,
)
yield sem_cache
sem_cache.collection.astra_db.delete_collection(
"lc_integration_test_sem_cache_async"
)
@pytest.mark.requires("astrapy") @pytest.mark.requires("astrapy")
@pytest.mark.skipif(not _has_env_vars(), reason="Missing Astra DB env. vars") @pytest.mark.skipif(not _has_env_vars(), reason="Missing Astra DB env. vars")
class TestAstraDBCaches: class TestAstraDBCaches:
def test_astradb_cache(self, astradb_cache: AstraDBCache) -> None: def test_astradb_cache(self, astradb_cache: AstraDBCache) -> None:
set_llm_cache(astradb_cache) self.do_cache_test(FakeLLM(), astradb_cache, "foo")
async def test_astradb_cache_async(self, async_astradb_cache: AstraDBCache) -> None:
await self.ado_cache_test(FakeLLM(), async_astradb_cache, "foo")
def test_astradb_semantic_cache(
self, astradb_semantic_cache: AstraDBSemanticCache
) -> None:
llm = FakeLLM() llm = FakeLLM()
params = llm.dict() self.do_cache_test(llm, astradb_semantic_cache, "bar")
params["stop"] = None output = llm.generate(["bar"]) # 'fizz' is erased away now
llm_string = str(sorted([(k, v) for k, v in params.items()])) assert output != LLMResult(
get_llm_cache().update("foo", llm_string, [Generation(text="fizz")])
output = llm.generate(["foo"])
print(output) # noqa: T201
expected_output = LLMResult(
generations=[[Generation(text="fizz")]], generations=[[Generation(text="fizz")]],
llm_output={}, llm_output={},
) )
print(expected_output) # noqa: T201 astradb_semantic_cache.clear()
assert output == expected_output
astradb_cache.clear()
def test_cassandra_semantic_cache( async def test_astradb_semantic_cache_async(
self, astradb_semantic_cache: AstraDBSemanticCache self, async_astradb_semantic_cache: AstraDBSemanticCache
) -> None: ) -> None:
set_llm_cache(astradb_semantic_cache)
llm = FakeLLM() llm = FakeLLM()
await self.ado_cache_test(llm, async_astradb_semantic_cache, "bar")
output = await llm.agenerate(["bar"]) # 'fizz' is erased away now
assert output != LLMResult(
generations=[[Generation(text="fizz")]],
llm_output={},
)
await async_astradb_semantic_cache.aclear()
@staticmethod
def do_cache_test(llm: LLM, cache: BaseCache, prompt: str) -> None:
set_llm_cache(cache)
params = llm.dict() params = llm.dict()
params["stop"] = None params["stop"] = None
llm_string = str(sorted([(k, v) for k, v in params.items()])) llm_string = str(sorted([(k, v) for k, v in params.items()]))
get_llm_cache().update("foo", llm_string, [Generation(text="fizz")]) get_llm_cache().update("foo", llm_string, [Generation(text="fizz")])
output = llm.generate(["bar"]) # same embedding as 'foo' output = llm.generate([prompt])
expected_output = LLMResult( expected_output = LLMResult(
generations=[[Generation(text="fizz")]], generations=[[Generation(text="fizz")]],
llm_output={}, llm_output={},
) )
assert output == expected_output assert output == expected_output
# clear the cache # clear the cache
astradb_semantic_cache.clear() cache.clear()
output = llm.generate(["bar"]) # 'fizz' is erased away now
assert output != expected_output @staticmethod
astradb_semantic_cache.clear() async def ado_cache_test(llm: LLM, cache: BaseCache, prompt: str) -> None:
set_llm_cache(cache)
params = llm.dict()
params["stop"] = None
llm_string = str(sorted([(k, v) for k, v in params.items()]))
await get_llm_cache().aupdate("foo", llm_string, [Generation(text="fizz")])
output = await llm.agenerate([prompt])
expected_output = LLMResult(
generations=[[Generation(text="fizz")]],
llm_output={},
)
assert output == expected_output
# clear the cache
await cache.aclear()