mirror of
https://github.com/hwchase17/langchain.git
synced 2025-07-04 12:18:24 +00:00
community: Add async methods to AstraDBCache (#17415)
Adds async methods to AstraDBCache
This commit is contained in:
parent
e438fe6be9
commit
ca2d4078f3
@ -29,12 +29,14 @@ import uuid
|
||||
import warnings
|
||||
from abc import ABC
|
||||
from datetime import timedelta
|
||||
from functools import lru_cache
|
||||
from functools import lru_cache, wraps
|
||||
from typing import (
|
||||
TYPE_CHECKING,
|
||||
Any,
|
||||
Awaitable,
|
||||
Callable,
|
||||
Dict,
|
||||
Generator,
|
||||
List,
|
||||
Optional,
|
||||
Sequence,
|
||||
@ -56,20 +58,23 @@ except ImportError:
|
||||
|
||||
from langchain_core.caches import RETURN_VAL_TYPE, BaseCache
|
||||
from langchain_core.embeddings import Embeddings
|
||||
from langchain_core.language_models.llms import LLM, get_prompts
|
||||
from langchain_core.language_models.llms import LLM, aget_prompts, get_prompts
|
||||
from langchain_core.load.dump import dumps
|
||||
from langchain_core.load.load import loads
|
||||
from langchain_core.outputs import ChatGeneration, Generation
|
||||
from langchain_core.utils import get_from_env
|
||||
|
||||
from langchain_community.utilities.astradb import AstraDBEnvironment
|
||||
from langchain_community.utilities.astradb import (
|
||||
SetupMode,
|
||||
_AstraDBCollectionEnvironment,
|
||||
)
|
||||
from langchain_community.vectorstores.redis import Redis as RedisVectorstore
|
||||
|
||||
logger = logging.getLogger(__file__)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
import momento
|
||||
from astrapy.db import AstraDB
|
||||
from astrapy.db import AstraDB, AsyncAstraDB
|
||||
from cassandra.cluster import Session as CassandraSession
|
||||
|
||||
|
||||
@ -1371,6 +1376,10 @@ class AstraDBCache(BaseCache):
|
||||
(needed to prevent same-prompt-different-model collisions)
|
||||
"""
|
||||
|
||||
@staticmethod
def _make_id(prompt: str, llm_string: str) -> str:
    """Build the document id used for a (prompt, llm_string) pair.

    The id is the hash of the prompt and the hash of the LLM string,
    joined by a ``#``.
    """
    prompt_hash = _hash(prompt)
    llm_hash = _hash(llm_string)
    return f"{prompt_hash}#{llm_hash}"
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
@ -1378,7 +1387,10 @@ class AstraDBCache(BaseCache):
|
||||
token: Optional[str] = None,
|
||||
api_endpoint: Optional[str] = None,
|
||||
astra_db_client: Optional[AstraDB] = None,
|
||||
async_astra_db_client: Optional[AsyncAstraDB] = None,
|
||||
namespace: Optional[str] = None,
|
||||
pre_delete_collection: bool = False,
|
||||
setup_mode: SetupMode = SetupMode.SYNC,
|
||||
):
|
||||
"""
|
||||
Create an AstraDB cache using a collection for storage.
|
||||
@ -1388,29 +1400,35 @@ class AstraDBCache(BaseCache):
|
||||
token (Optional[str]): API token for Astra DB usage.
|
||||
api_endpoint (Optional[str]): full URL to the API endpoint,
|
||||
such as "https://<DB-ID>-us-east1.apps.astra.datastax.com".
|
||||
astra_db_client (Optional[Any]): *alternative to token+api_endpoint*,
|
||||
astra_db_client (Optional[AstraDB]):
|
||||
*alternative to token+api_endpoint*,
|
||||
you can pass an already-created 'astrapy.db.AstraDB' instance.
|
||||
async_astra_db_client (Optional[AsyncAstraDB]):
|
||||
*alternative to token+api_endpoint*,
|
||||
you can pass an already-created 'astrapy.db.AsyncAstraDB' instance.
|
||||
namespace (Optional[str]): namespace (aka keyspace) where the
|
||||
collection is created. Defaults to the database's "default namespace".
|
||||
pre_delete_collection (bool): whether to delete and re-create the
|
||||
collection. Defaults to False.
|
||||
async_setup (bool): whether to create the collection asynchronously.
|
||||
Enable only if there is a running asyncio event loop. Defaults to False.
|
||||
"""
|
||||
astra_env = AstraDBEnvironment(
|
||||
self.astra_env = _AstraDBCollectionEnvironment(
|
||||
collection_name=collection_name,
|
||||
token=token,
|
||||
api_endpoint=api_endpoint,
|
||||
astra_db_client=astra_db_client,
|
||||
async_astra_db_client=async_astra_db_client,
|
||||
namespace=namespace,
|
||||
setup_mode=setup_mode,
|
||||
pre_delete_collection=pre_delete_collection,
|
||||
)
|
||||
self.astra_db = astra_env.astra_db
|
||||
self.collection = self.astra_db.create_collection(
|
||||
collection_name=collection_name,
|
||||
)
|
||||
self.collection_name = collection_name
|
||||
|
||||
@staticmethod
|
||||
def _make_id(prompt: str, llm_string: str) -> str:
|
||||
return f"{_hash(prompt)}#{_hash(llm_string)}"
|
||||
self.collection = self.astra_env.collection
|
||||
self.async_collection = self.astra_env.async_collection
|
||||
|
||||
def lookup(self, prompt: str, llm_string: str) -> Optional[RETURN_VAL_TYPE]:
|
||||
"""Look up based on prompt and llm_string."""
|
||||
self.astra_env.ensure_db_setup()
|
||||
doc_id = self._make_id(prompt, llm_string)
|
||||
item = self.collection.find_one(
|
||||
filter={
|
||||
@ -1420,18 +1438,27 @@ class AstraDBCache(BaseCache):
|
||||
"body_blob": 1,
|
||||
},
|
||||
)["data"]["document"]
|
||||
if item is not None:
|
||||
generations = _loads_generations(item["body_blob"])
|
||||
# this protects against malformed cached items:
|
||||
if generations is not None:
|
||||
return generations
|
||||
else:
|
||||
return None
|
||||
else:
|
||||
return None
|
||||
return _loads_generations(item["body_blob"]) if item is not None else None
|
||||
|
||||
async def alookup(self, prompt: str, llm_string: str) -> Optional[RETURN_VAL_TYPE]:
    """Asynchronously look up a cached entry by prompt and llm_string.

    Waits for the collection setup to finish, then fetches the document
    whose ``_id`` is derived from the two arguments, projecting only the
    ``body_blob`` field.

    Returns:
        ``None`` if no document is found, otherwise the result of
        ``_loads_generations`` applied to the stored ``body_blob``.
    """
    await self.astra_env.aensure_db_setup()
    cache_key = self._make_id(prompt, llm_string)
    response = await self.async_collection.find_one(
        filter={"_id": cache_key},
        projection={"body_blob": 1},
    )
    document = response["data"]["document"]
    if document is None:
        return None
    return _loads_generations(document["body_blob"])
|
||||
|
||||
def update(self, prompt: str, llm_string: str, return_val: RETURN_VAL_TYPE) -> None:
|
||||
"""Update cache based on prompt and llm_string."""
|
||||
self.astra_env.ensure_db_setup()
|
||||
doc_id = self._make_id(prompt, llm_string)
|
||||
blob = _dumps_generations(return_val)
|
||||
self.collection.upsert(
|
||||
@ -1441,6 +1468,20 @@ class AstraDBCache(BaseCache):
|
||||
},
|
||||
)
|
||||
|
||||
async def aupdate(
    self, prompt: str, llm_string: str, return_val: RETURN_VAL_TYPE
) -> None:
    """Asynchronously store ``return_val`` under the (prompt, llm_string) key.

    Waits for the collection setup to finish, serializes the generations
    with ``_dumps_generations`` and upserts a document made of ``_id`` and
    ``body_blob``.
    """
    await self.astra_env.aensure_db_setup()
    cache_key = self._make_id(prompt, llm_string)
    serialized = _dumps_generations(return_val)
    document = {"_id": cache_key, "body_blob": serialized}
    await self.async_collection.upsert(document)
|
||||
|
||||
def delete_through_llm(
|
||||
self, prompt: str, llm: LLM, stop: Optional[List[str]] = None
|
||||
) -> None:
|
||||
@ -1454,14 +1495,42 @@ class AstraDBCache(BaseCache):
|
||||
)[1]
|
||||
return self.delete(prompt, llm_string=llm_string)
|
||||
|
||||
async def adelete_through_llm(
    self, prompt: str, llm: LLM, stop: Optional[List[str]] = None
) -> None:
    """
    A wrapper around `adelete` with the LLM being passed.
    In case the llm(prompt) calls have a `stop` param, you should pass it here
    """
    # The LLM string is the second item returned by `aget_prompts` when it
    # is called with the LLM parameters (plus `stop`) and no prompts.
    llm_params = {**llm.dict(), "stop": stop}
    prompts_result = await aget_prompts(llm_params, [])
    llm_string = prompts_result[1]
    return await self.adelete(prompt, llm_string=llm_string)
|
||||
|
||||
def delete(self, prompt: str, llm_string: str) -> None:
    """Evict the entry for (prompt, llm_string) from the cache, if present.

    Checks that the collection setup is complete before issuing the
    ``delete_one`` call with the derived document id.
    """
    self.astra_env.ensure_db_setup()
    target_id = self._make_id(prompt, llm_string)
    self.collection.delete_one(target_id)
|
||||
|
||||
async def adelete(self, prompt: str, llm_string: str) -> None:
    """Asynchronously evict the entry for (prompt, llm_string), if present.

    Awaits the collection setup before awaiting ``delete_one`` on the
    async collection with the derived document id.
    """
    await self.astra_env.aensure_db_setup()
    target_id = self._make_id(prompt, llm_string)
    await self.async_collection.delete_one(target_id)
|
||||
|
||||
def clear(self, **kwargs: Any) -> None:
    """Clear cache. This is for all LLMs at once.

    Checks that the collection setup is complete, then empties the
    collection through the collection object itself. The cache no longer
    goes through a separate ``self.astra_db`` handle for this, so nothing
    touches the database before the setup check.

    Args:
        **kwargs: accepted for interface compatibility; not used here.
    """
    self.astra_env.ensure_db_setup()
    self.collection.clear()
|
||||
|
||||
async def aclear(self, **kwargs: Any) -> None:
    """Asynchronously clear the cache. This is for all LLMs at once.

    Args:
        **kwargs: accepted for interface compatibility; not used here.
    """
    environment = self.astra_env
    await environment.aensure_db_setup()
    await self.async_collection.clear()
|
||||
|
||||
|
||||
ASTRA_DB_SEMANTIC_CACHE_DEFAULT_THRESHOLD = 0.85
|
||||
@ -1469,6 +1538,42 @@ ASTRA_DB_CACHE_DEFAULT_COLLECTION_NAME = "langchain_astradb_semantic_cache"
|
||||
ASTRA_DB_SEMANTIC_CACHE_EMBEDDING_CACHE_SIZE = 16
|
||||
|
||||
|
||||
# Sentinel meaning "no result stored yet". It is compared by identity, so
# any real result (including None) can be cached.
_unset = ["unset"]


class _CachedAwaitable:
    """Caches the result of an awaitable so it can be awaited multiple times.

    The wrapped awaitable is scheduled exactly once, on the first await,
    via ``asyncio.ensure_future``. Every awaiter (including ones that
    arrive while the first is still pending) waits on that single future
    instead of driving the underlying coroutine a second time. Once a
    result is available it is stored in ``result`` and returned directly.

    If the wrapped awaitable raises, nothing is stored in ``result`` and
    each await re-raises the original exception from the shared future.
    """

    def __init__(self, awaitable: Awaitable[Any]):
        self.awaitable = awaitable
        self.result: Any = _unset
        # Created lazily: an event loop is only guaranteed at await time.
        self._future: Any = None

    def __await__(self) -> Generator:
        if self.result is _unset:
            if self._future is None:
                # Local import keeps this block self-contained.
                import asyncio

                self._future = asyncio.ensure_future(self.awaitable)
            self.result = yield from self._future.__await__()
        return self.result
|
||||
|
||||
|
||||
def _reawaitable(func: Callable) -> Callable:
    """Decorate an async function so each call's result can be re-awaited.

    Every call to the decorated function invokes ``func`` and wraps what it
    returns in a ``_CachedAwaitable``.
    """

    @wraps(func)
    def cached_call(*args: Any, **kwargs: Any) -> _CachedAwaitable:
        pending = func(*args, **kwargs)
        return _CachedAwaitable(pending)

    return cached_call
|
||||
|
||||
|
||||
def _async_lru_cache(maxsize: int = 128, typed: bool = False) -> Callable:
    """Least-recently-used cache decorator for async functions.

    Equivalent to functools.lru_cache for async functions: the function is
    first made re-awaitable, and the resulting awaitables are what
    ``lru_cache`` stores.

    Args:
        maxsize: forwarded to ``functools.lru_cache``.
        typed: forwarded to ``functools.lru_cache``.
    """

    def decorating_function(user_function: Callable) -> Callable:
        cache_decorator = lru_cache(maxsize, typed)
        reawaitable_function = _reawaitable(user_function)
        return cache_decorator(reawaitable_function)

    return decorating_function
|
||||
|
||||
|
||||
class AstraDBSemanticCache(BaseCache):
|
||||
"""
|
||||
Cache that uses Astra DB as a vector-store backend for semantic
|
||||
@ -1479,7 +1584,7 @@ class AstraDBSemanticCache(BaseCache):
|
||||
in the document metadata.
|
||||
|
||||
You can choose the preferred similarity (or use the API default) --
|
||||
remember the threshold might require metric-dependend tuning.
|
||||
remember the threshold might require metric-dependent tuning.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
@ -1489,7 +1594,10 @@ class AstraDBSemanticCache(BaseCache):
|
||||
token: Optional[str] = None,
|
||||
api_endpoint: Optional[str] = None,
|
||||
astra_db_client: Optional[AstraDB] = None,
|
||||
async_astra_db_client: Optional[AsyncAstraDB] = None,
|
||||
namespace: Optional[str] = None,
|
||||
setup_mode: SetupMode = SetupMode.SYNC,
|
||||
pre_delete_collection: bool = False,
|
||||
embedding: Embeddings,
|
||||
metric: Optional[str] = None,
|
||||
similarity_threshold: float = ASTRA_DB_SEMANTIC_CACHE_DEFAULT_THRESHOLD,
|
||||
@ -1502,10 +1610,17 @@ class AstraDBSemanticCache(BaseCache):
|
||||
token (Optional[str]): API token for Astra DB usage.
|
||||
api_endpoint (Optional[str]): full URL to the API endpoint,
|
||||
such as "https://<DB-ID>-us-east1.apps.astra.datastax.com".
|
||||
astra_db_client (Optional[Any]): *alternative to token+api_endpoint*,
|
||||
astra_db_client (Optional[AstraDB]): *alternative to token+api_endpoint*,
|
||||
you can pass an already-created 'astrapy.db.AstraDB' instance.
|
||||
async_astra_db_client (Optional[AsyncAstraDB]):
|
||||
*alternative to token+api_endpoint*,
|
||||
you can pass an already-created 'astrapy.db.AsyncAstraDB' instance.
|
||||
namespace (Optional[str]): namespace (aka keyspace) where the
|
||||
collection is created. Defaults to the database's "default namespace".
|
||||
setup_mode (SetupMode): mode used to create the collection in the DB
|
||||
(SYNC, ASYNC or OFF). Defaults to SYNC.
|
||||
pre_delete_collection (bool): whether to delete and re-create the
|
||||
collection. Defaults to False.
|
||||
embedding (Embedding): Embedding provider for semantic
|
||||
encoding and search.
|
||||
metric: the function to use for evaluating similarity of text embeddings.
|
||||
@ -1516,17 +1631,10 @@ class AstraDBSemanticCache(BaseCache):
|
||||
The default score threshold is tuned to the default metric.
|
||||
Tune it carefully yourself if switching to another distance metric.
|
||||
"""
|
||||
astra_env = AstraDBEnvironment(
|
||||
token=token,
|
||||
api_endpoint=api_endpoint,
|
||||
astra_db_client=astra_db_client,
|
||||
namespace=namespace,
|
||||
)
|
||||
self.astra_db = astra_env.astra_db
|
||||
|
||||
self.embedding = embedding
|
||||
self.metric = metric
|
||||
self.similarity_threshold = similarity_threshold
|
||||
self.collection_name = collection_name
|
||||
|
||||
# The contract for this class has separate lookup and update:
|
||||
# in order to spare some embedding calculations we cache them between
|
||||
@ -1538,25 +1646,47 @@ class AstraDBSemanticCache(BaseCache):
|
||||
return self.embedding.embed_query(text=text)
|
||||
|
||||
self._get_embedding = _cache_embedding
|
||||
self.embedding_dimension = self._get_embedding_dimension()
|
||||
|
||||
self.collection_name = collection_name
|
||||
@_async_lru_cache(maxsize=ASTRA_DB_SEMANTIC_CACHE_EMBEDDING_CACHE_SIZE)
|
||||
async def _acache_embedding(text: str) -> List[float]:
|
||||
return await self.embedding.aembed_query(text=text)
|
||||
|
||||
self.collection = self.astra_db.create_collection(
|
||||
collection_name=self.collection_name,
|
||||
dimension=self.embedding_dimension,
|
||||
metric=self.metric,
|
||||
self._aget_embedding = _acache_embedding
|
||||
|
||||
embedding_dimension: Union[int, Awaitable[int], None] = None
|
||||
if setup_mode == SetupMode.ASYNC:
|
||||
embedding_dimension = self._aget_embedding_dimension()
|
||||
elif setup_mode == SetupMode.SYNC:
|
||||
embedding_dimension = self._get_embedding_dimension()
|
||||
|
||||
self.astra_env = _AstraDBCollectionEnvironment(
|
||||
collection_name=collection_name,
|
||||
token=token,
|
||||
api_endpoint=api_endpoint,
|
||||
astra_db_client=astra_db_client,
|
||||
async_astra_db_client=async_astra_db_client,
|
||||
namespace=namespace,
|
||||
setup_mode=setup_mode,
|
||||
pre_delete_collection=pre_delete_collection,
|
||||
embedding_dimension=embedding_dimension,
|
||||
metric=metric,
|
||||
)
|
||||
self.collection = self.astra_env.collection
|
||||
self.async_collection = self.astra_env.async_collection
|
||||
|
||||
def _get_embedding_dimension(self) -> int:
    """Return the length of the embedding computed for a fixed sample sentence."""
    sample_vector = self._get_embedding(text="This is a sample sentence.")
    return len(sample_vector)
|
||||
|
||||
async def _aget_embedding_dimension(self) -> int:
    """Asynchronously return the embedding length for a fixed sample sentence."""
    sample_vector = await self._aget_embedding(text="This is a sample sentence.")
    return len(sample_vector)
|
||||
|
||||
@staticmethod
def _make_id(prompt: str, llm_string: str) -> str:
    """Build the document id used for a (prompt, llm_string) pair.

    The id is the hash of the prompt and the hash of the LLM string,
    joined by a ``#``.
    """
    prompt_hash = _hash(prompt)
    llm_hash = _hash(llm_string)
    return f"{prompt_hash}#{llm_hash}"
|
||||
|
||||
def update(self, prompt: str, llm_string: str, return_val: RETURN_VAL_TYPE) -> None:
|
||||
"""Update cache based on prompt and llm_string."""
|
||||
self.astra_env.ensure_db_setup()
|
||||
doc_id = self._make_id(prompt, llm_string)
|
||||
llm_string_hash = _hash(llm_string)
|
||||
embedding_vector = self._get_embedding(text=prompt)
|
||||
@ -1571,6 +1701,25 @@ class AstraDBSemanticCache(BaseCache):
|
||||
}
|
||||
)
|
||||
|
||||
async def aupdate(
    self, prompt: str, llm_string: str, return_val: RETURN_VAL_TYPE
) -> None:
    """Asynchronously store ``return_val`` for the given prompt and llm_string.

    Waits for the collection setup, then upserts one document holding the
    derived id, the serialized generations, the hash of the LLM string and
    the prompt embedding (as ``$vector``).
    """
    await self.astra_env.aensure_db_setup()
    cache_key = self._make_id(prompt, llm_string)
    hashed_llm_string = _hash(llm_string)
    prompt_vector = await self._aget_embedding(text=prompt)
    serialized = _dumps_generations(return_val)
    document = {
        "_id": cache_key,
        "body_blob": serialized,
        "llm_string_hash": hashed_llm_string,
        "$vector": prompt_vector,
    }
    await self.async_collection.upsert(document)
|
||||
|
||||
def lookup(self, prompt: str, llm_string: str) -> Optional[RETURN_VAL_TYPE]:
|
||||
"""Look up based on prompt and llm_string."""
|
||||
hit_with_id = self.lookup_with_id(prompt, llm_string)
|
||||
@ -1579,6 +1728,14 @@ class AstraDBSemanticCache(BaseCache):
|
||||
else:
|
||||
return None
|
||||
|
||||
async def alookup(self, prompt: str, llm_string: str) -> Optional[RETURN_VAL_TYPE]:
    """Asynchronously look up based on prompt and llm_string.

    Delegates to ``alookup_with_id`` and drops the document id.

    Returns:
        The cached entry of the top hit, or ``None`` when there is no hit.
    """
    hit_with_id = await self.alookup_with_id(prompt, llm_string)
    return None if hit_with_id is None else hit_with_id[1]
|
||||
|
||||
def lookup_with_id(
|
||||
self, prompt: str, llm_string: str
|
||||
) -> Optional[Tuple[str, RETURN_VAL_TYPE]]:
|
||||
@ -1586,6 +1743,7 @@ class AstraDBSemanticCache(BaseCache):
|
||||
Look up based on prompt and llm_string.
|
||||
If there are hits, return (document_id, cached_entry) for the top hit
|
||||
"""
|
||||
self.astra_env.ensure_db_setup()
|
||||
prompt_embedding: List[float] = self._get_embedding(text=prompt)
|
||||
llm_string_hash = _hash(llm_string)
|
||||
|
||||
@ -1604,7 +1762,37 @@ class AstraDBSemanticCache(BaseCache):
|
||||
generations = _loads_generations(hit["body_blob"])
|
||||
if generations is not None:
|
||||
# this protects against malformed cached items:
|
||||
return (hit["_id"], generations)
|
||||
return hit["_id"], generations
|
||||
else:
|
||||
return None
|
||||
|
||||
async def alookup_with_id(
    self, prompt: str, llm_string: str
) -> Optional[Tuple[str, RETURN_VAL_TYPE]]:
    """
    Asynchronously look up based on prompt and llm_string.
    If there are hits, return (document_id, cached_entry) for the top hit.

    The search is a vector search on the prompt embedding, restricted to
    documents with the same LLM-string hash. A hit whose ``$similarity``
    is below ``self.similarity_threshold`` counts as a miss.
    """
    await self.astra_env.aensure_db_setup()
    query_vector: List[float] = await self._aget_embedding(text=prompt)
    hashed_llm_string = _hash(llm_string)

    best_match = await self.async_collection.vector_find_one(
        vector=query_vector,
        filter={"llm_string_hash": hashed_llm_string},
        fields=["body_blob", "_id"],
        include_similarity=True,
    )

    if best_match is None:
        return None
    if best_match["$similarity"] < self.similarity_threshold:
        return None

    generations = _loads_generations(best_match["body_blob"])
    # this protects against malformed cached items:
    if generations is None:
        return None
    return best_match["_id"], generations
|
||||
|
||||
@ -1617,14 +1805,41 @@ class AstraDBSemanticCache(BaseCache):
|
||||
)[1]
|
||||
return self.lookup_with_id(prompt, llm_string=llm_string)
|
||||
|
||||
async def alookup_with_id_through_llm(
    self, prompt: str, llm: LLM, stop: Optional[List[str]] = None
) -> Optional[Tuple[str, RETURN_VAL_TYPE]]:
    """Asynchronous `alookup_with_id` taking the LLM instead of its string.

    The LLM string is the second item returned by `aget_prompts` when it
    is called with the LLM parameters (plus `stop`) and no prompts.
    """
    llm_params = {**llm.dict(), "stop": stop}
    prompts_result = await aget_prompts(llm_params, [])
    llm_string = prompts_result[1]
    return await self.alookup_with_id(prompt, llm_string=llm_string)
|
||||
|
||||
def delete_by_document_id(self, document_id: str) -> None:
    """
    Delete the cached document with the given id.

    Given this is a "similarity search" cache, an invalidation pattern
    that makes sense is first a lookup to get an ID, and then deleting
    with that ID. This is for the second step.
    """
    environment = self.astra_env
    environment.ensure_db_setup()
    self.collection.delete_one(document_id)
|
||||
|
||||
async def adelete_by_document_id(self, document_id: str) -> None:
    """
    Asynchronously delete the cached document with the given id.

    Given this is a "similarity search" cache, an invalidation pattern
    that makes sense is first a lookup to get an ID, and then deleting
    with that ID. This is for the second step.
    """
    environment = self.astra_env
    await environment.aensure_db_setup()
    await self.async_collection.delete_one(document_id)
|
||||
|
||||
def clear(self, **kwargs: Any) -> None:
    """Clear the *whole* semantic cache.

    Checks that the collection setup is complete, then empties the
    collection through the collection object itself. The cache no longer
    goes through a separate ``self.astra_db`` handle for this, so nothing
    touches the database before the setup check.

    Args:
        **kwargs: accepted for interface compatibility; not used here.
    """
    self.astra_env.ensure_db_setup()
    self.collection.clear()
|
||||
|
||||
async def aclear(self, **kwargs: Any) -> None:
    """Asynchronously clear the *whole* semantic cache.

    Args:
        **kwargs: accepted for interface compatibility; not used here.
    """
    environment = self.astra_env
    await environment.aensure_db_setup()
    await self.async_collection.clear()
|
||||
|
@ -5,7 +5,7 @@ import json
|
||||
import time
|
||||
from typing import TYPE_CHECKING, List, Optional
|
||||
|
||||
from langchain_community.utilities.astradb import AstraDBEnvironment
|
||||
from langchain_community.utilities.astradb import _AstraDBEnvironment
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from astrapy.db import AstraDB
|
||||
@ -47,7 +47,7 @@ class AstraDBChatMessageHistory(BaseChatMessageHistory):
|
||||
namespace: Optional[str] = None,
|
||||
) -> None:
|
||||
"""Create an Astra DB chat message history."""
|
||||
astra_env = AstraDBEnvironment(
|
||||
astra_env = _AstraDBEnvironment(
|
||||
token=token,
|
||||
api_endpoint=api_endpoint,
|
||||
astra_db_client=astra_db_client,
|
||||
|
@ -19,7 +19,7 @@ from langchain_core.documents import Document
|
||||
from langchain_core.runnables import run_in_executor
|
||||
|
||||
from langchain_community.document_loaders.base import BaseLoader
|
||||
from langchain_community.utilities.astradb import AstraDBEnvironment
|
||||
from langchain_community.utilities.astradb import _AstraDBEnvironment
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from astrapy.db import AstraDB, AsyncAstraDB
|
||||
@ -44,7 +44,7 @@ class AstraDBLoader(BaseLoader):
|
||||
nb_prefetched: int = 1000,
|
||||
extraction_function: Callable[[Dict], str] = json.dumps,
|
||||
) -> None:
|
||||
astra_env = AstraDBEnvironment(
|
||||
astra_env = _AstraDBEnvironment(
|
||||
token=token,
|
||||
api_endpoint=api_endpoint,
|
||||
astra_db_client=astra_db_client,
|
||||
|
@ -16,7 +16,7 @@ from typing import (
|
||||
|
||||
from langchain_core.stores import BaseStore, ByteStore
|
||||
|
||||
from langchain_community.utilities.astradb import AstraDBEnvironment
|
||||
from langchain_community.utilities.astradb import _AstraDBEnvironment
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from astrapy.db import AstraDB
|
||||
@ -35,7 +35,7 @@ class AstraDBBaseStore(Generic[V], BaseStore[str, V], ABC):
|
||||
astra_db_client: Optional[AstraDB] = None,
|
||||
namespace: Optional[str] = None,
|
||||
) -> None:
|
||||
astra_env = AstraDBEnvironment(
|
||||
astra_env = _AstraDBEnvironment(
|
||||
token=token,
|
||||
api_endpoint=api_endpoint,
|
||||
astra_db_client=astra_db_client,
|
||||
|
@ -1,6 +1,10 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import TYPE_CHECKING, Optional
|
||||
import asyncio
|
||||
import inspect
|
||||
from asyncio import InvalidStateError, Task
|
||||
from enum import Enum
|
||||
from typing import TYPE_CHECKING, Awaitable, Optional, Union
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from astrapy.db import (
|
||||
@ -9,7 +13,13 @@ if TYPE_CHECKING:
|
||||
)
|
||||
|
||||
|
||||
class AstraDBEnvironment:
|
||||
class SetupMode(Enum):
    """Mode used to create the collection in the DB (SYNC, ASYNC or OFF)."""

    # The collection is created with blocking calls while the environment
    # object is being constructed.
    SYNC = 1
    # Collection creation is scheduled as an asyncio task at construction.
    ASYNC = 2
    # Neither the sync nor the async creation path runs.
    OFF = 3
|
||||
|
||||
|
||||
class _AstraDBEnvironment:
|
||||
def __init__(
|
||||
self,
|
||||
token: Optional[str] = None,
|
||||
@ -21,21 +31,20 @@ class AstraDBEnvironment:
|
||||
self.token = token
|
||||
self.api_endpoint = api_endpoint
|
||||
astra_db = astra_db_client
|
||||
self.async_astra_db = async_astra_db_client
|
||||
async_astra_db = async_astra_db_client
|
||||
self.namespace = namespace
|
||||
|
||||
from astrapy import db
|
||||
|
||||
try:
|
||||
from astrapy.db import AstraDB
|
||||
from astrapy.db import (
|
||||
AstraDB,
|
||||
AsyncAstraDB,
|
||||
)
|
||||
except (ImportError, ModuleNotFoundError):
|
||||
raise ImportError(
|
||||
"Could not import a recent astrapy python package. "
|
||||
"Please install it with `pip install --upgrade astrapy`."
|
||||
)
|
||||
|
||||
supports_async = hasattr(db, "AsyncAstraDB")
|
||||
|
||||
# Conflicting-arg checks:
|
||||
if astra_db_client is not None or async_astra_db_client is not None:
|
||||
if token is not None or api_endpoint is not None:
|
||||
@ -46,39 +55,115 @@ class AstraDBEnvironment:
|
||||
|
||||
if token and api_endpoint:
|
||||
astra_db = AstraDB(
|
||||
token=self.token,
|
||||
api_endpoint=self.api_endpoint,
|
||||
token=token,
|
||||
api_endpoint=api_endpoint,
|
||||
namespace=self.namespace,
|
||||
)
|
||||
async_astra_db = AsyncAstraDB(
|
||||
token=token,
|
||||
api_endpoint=api_endpoint,
|
||||
namespace=self.namespace,
|
||||
)
|
||||
if supports_async:
|
||||
self.async_astra_db = db.AsyncAstraDB(
|
||||
token=self.token,
|
||||
api_endpoint=self.api_endpoint,
|
||||
namespace=self.namespace,
|
||||
)
|
||||
|
||||
if astra_db:
|
||||
self.astra_db = astra_db
|
||||
else:
|
||||
if self.async_astra_db:
|
||||
self.astra_db = AstraDB(
|
||||
token=self.async_astra_db.token,
|
||||
api_endpoint=self.async_astra_db.base_url,
|
||||
api_path=self.async_astra_db.api_path,
|
||||
api_version=self.async_astra_db.api_version,
|
||||
namespace=self.async_astra_db.namespace,
|
||||
)
|
||||
if async_astra_db:
|
||||
self.async_astra_db = async_astra_db
|
||||
else:
|
||||
raise ValueError(
|
||||
"Must provide 'astra_db_client' or 'async_astra_db_client' or "
|
||||
"'token' and 'api_endpoint'"
|
||||
self.async_astra_db = AsyncAstraDB(
|
||||
token=self.astra_db.token,
|
||||
api_endpoint=self.astra_db.base_url,
|
||||
api_path=self.astra_db.api_path,
|
||||
api_version=self.astra_db.api_version,
|
||||
namespace=self.astra_db.namespace,
|
||||
)
|
||||
elif async_astra_db:
|
||||
self.async_astra_db = async_astra_db
|
||||
self.astra_db = AstraDB(
|
||||
token=self.async_astra_db.token,
|
||||
api_endpoint=self.async_astra_db.base_url,
|
||||
api_path=self.async_astra_db.api_path,
|
||||
api_version=self.async_astra_db.api_version,
|
||||
namespace=self.async_astra_db.namespace,
|
||||
)
|
||||
else:
|
||||
raise ValueError(
|
||||
"Must provide 'astra_db_client' or 'async_astra_db_client' or "
|
||||
"'token' and 'api_endpoint'"
|
||||
)
|
||||
|
||||
|
||||
class _AstraDBCollectionEnvironment(_AstraDBEnvironment):
|
||||
def __init__(
|
||||
self,
|
||||
collection_name: str,
|
||||
token: Optional[str] = None,
|
||||
api_endpoint: Optional[str] = None,
|
||||
astra_db_client: Optional[AstraDB] = None,
|
||||
async_astra_db_client: Optional[AsyncAstraDB] = None,
|
||||
namespace: Optional[str] = None,
|
||||
setup_mode: SetupMode = SetupMode.SYNC,
|
||||
pre_delete_collection: bool = False,
|
||||
embedding_dimension: Union[int, Awaitable[int], None] = None,
|
||||
metric: Optional[str] = None,
|
||||
) -> None:
|
||||
from astrapy.db import AstraDBCollection, AsyncAstraDBCollection
|
||||
|
||||
super().__init__(
|
||||
token, api_endpoint, astra_db_client, async_astra_db_client, namespace
|
||||
)
|
||||
self.collection_name = collection_name
|
||||
self.collection = AstraDBCollection(
|
||||
collection_name=collection_name,
|
||||
astra_db=self.astra_db,
|
||||
)
|
||||
|
||||
self.async_collection = AsyncAstraDBCollection(
|
||||
collection_name=collection_name,
|
||||
astra_db=self.async_astra_db,
|
||||
)
|
||||
|
||||
self.async_setup_db_task: Optional[Task] = None
|
||||
if setup_mode == SetupMode.ASYNC:
|
||||
async_astra_db = self.async_astra_db
|
||||
|
||||
async def _setup_db() -> None:
|
||||
if pre_delete_collection:
|
||||
await async_astra_db.delete_collection(collection_name)
|
||||
if inspect.isawaitable(embedding_dimension):
|
||||
dimension = await embedding_dimension
|
||||
else:
|
||||
dimension = embedding_dimension
|
||||
await async_astra_db.create_collection(
|
||||
collection_name, dimension=dimension, metric=metric
|
||||
)
|
||||
|
||||
if not self.async_astra_db and self.astra_db and supports_async:
|
||||
self.async_astra_db = db.AsyncAstraDB(
|
||||
token=self.astra_db.token,
|
||||
api_endpoint=self.astra_db.base_url,
|
||||
api_path=self.astra_db.api_path,
|
||||
api_version=self.astra_db.api_version,
|
||||
namespace=self.astra_db.namespace,
|
||||
self.async_setup_db_task = asyncio.create_task(_setup_db())
|
||||
elif setup_mode == SetupMode.SYNC:
|
||||
if pre_delete_collection:
|
||||
self.astra_db.delete_collection(collection_name)
|
||||
if inspect.isawaitable(embedding_dimension):
|
||||
raise ValueError(
|
||||
"Cannot use an awaitable embedding_dimension with async_setup "
|
||||
"set to False"
|
||||
)
|
||||
self.astra_db.create_collection(
|
||||
collection_name,
|
||||
dimension=embedding_dimension, # type: ignore[arg-type]
|
||||
metric=metric,
|
||||
)
|
||||
|
||||
def ensure_db_setup(self) -> None:
    """Check, without waiting, that any asynchronous DB setup has completed.

    When an async setup task exists, its ``result()`` is read: this
    returns immediately if the task finished and re-raises whatever the
    task itself raised. Nothing happens when there is no setup task.

    Raises:
        ValueError: if the setup task has not finished yet. The underlying
            ``InvalidStateError`` is attached as the explicit cause.
    """
    if self.async_setup_db_task:
        try:
            self.async_setup_db_task.result()
        except InvalidStateError as err:
            raise ValueError(
                "Asynchronous setup of the DB not finished. "
                "NB: AstraDB components sync methods shouldn't be called from the "
                "event loop. Consider using their async equivalents."
            ) from err
|
||||
|
||||
async def aensure_db_setup(self) -> None:
    """Wait for the asynchronous DB setup task, if there is one."""
    if not self.async_setup_db_task:
        return
    await self.async_setup_db_task
|
||||
|
@ -12,9 +12,12 @@ Required to run this test:
|
||||
"""
|
||||
|
||||
import os
|
||||
from typing import Iterator
|
||||
from typing import AsyncIterator, Iterator
|
||||
|
||||
import pytest
|
||||
from langchain_community.utilities.astradb import SetupMode
|
||||
from langchain_core.caches import BaseCache
|
||||
from langchain_core.language_models import LLM
|
||||
from langchain_core.outputs import Generation, LLMResult
|
||||
|
||||
from langchain.cache import AstraDBCache, AstraDBSemanticCache
|
||||
@ -41,7 +44,22 @@ def astradb_cache() -> Iterator[AstraDBCache]:
|
||||
namespace=os.environ.get("ASTRA_DB_KEYSPACE"),
|
||||
)
|
||||
yield cache
|
||||
cache.astra_db.delete_collection("lc_integration_test_cache")
|
||||
cache.collection.astra_db.delete_collection("lc_integration_test_cache")
|
||||
|
||||
|
||||
@pytest.fixture
async def async_astradb_cache() -> AsyncIterator[AstraDBCache]:
    """Yield an AstraDBCache built with asynchronous collection setup.

    Connection settings are read from the environment. After the test,
    the collection is deleted through the async client.
    """
    collection_name = "lc_integration_test_cache_async"
    cache = AstraDBCache(
        collection_name=collection_name,
        token=os.environ["ASTRA_DB_APPLICATION_TOKEN"],
        api_endpoint=os.environ["ASTRA_DB_API_ENDPOINT"],
        namespace=os.environ.get("ASTRA_DB_KEYSPACE"),
        setup_mode=SetupMode.ASYNC,
    )
    yield cache
    async_db = cache.async_collection.astra_db
    await async_db.delete_collection(collection_name)
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
@ -55,46 +73,87 @@ def astradb_semantic_cache() -> Iterator[AstraDBSemanticCache]:
|
||||
embedding=fake_embe,
|
||||
)
|
||||
yield sem_cache
|
||||
sem_cache.astra_db.delete_collection("lc_integration_test_cache")
|
||||
sem_cache.collection.astra_db.delete_collection("lc_integration_test_sem_cache")
|
||||
|
||||
|
||||
@pytest.fixture
async def async_astradb_semantic_cache() -> AsyncIterator[AstraDBSemanticCache]:
    """Yield an AstraDBSemanticCache built with asynchronous collection setup.

    Connection settings are read from the environment and a fake embedding
    provider is used. After the test, the collection is deleted through
    the async client, matching the other async cache fixture, so teardown
    does not make a blocking call from inside the event loop.
    """
    fake_embe = FakeEmbeddings()
    sem_cache = AstraDBSemanticCache(
        collection_name="lc_integration_test_sem_cache_async",
        token=os.environ["ASTRA_DB_APPLICATION_TOKEN"],
        api_endpoint=os.environ["ASTRA_DB_API_ENDPOINT"],
        namespace=os.environ.get("ASTRA_DB_KEYSPACE"),
        embedding=fake_embe,
        setup_mode=SetupMode.ASYNC,
    )
    yield sem_cache
    await sem_cache.async_collection.astra_db.delete_collection(
        "lc_integration_test_sem_cache_async"
    )
|
||||
|
||||
|
||||
@pytest.mark.requires("astrapy")
|
||||
@pytest.mark.skipif(not _has_env_vars(), reason="Missing Astra DB env. vars")
|
||||
class TestAstraDBCaches:
|
||||
def test_astradb_cache(self, astradb_cache: AstraDBCache) -> None:
|
||||
set_llm_cache(astradb_cache)
|
||||
self.do_cache_test(FakeLLM(), astradb_cache, "foo")
|
||||
|
||||
async def test_astradb_cache_async(self, async_astradb_cache: AstraDBCache) -> None:
|
||||
await self.ado_cache_test(FakeLLM(), async_astradb_cache, "foo")
|
||||
|
||||
def test_astradb_semantic_cache(
|
||||
self, astradb_semantic_cache: AstraDBSemanticCache
|
||||
) -> None:
|
||||
llm = FakeLLM()
|
||||
params = llm.dict()
|
||||
params["stop"] = None
|
||||
llm_string = str(sorted([(k, v) for k, v in params.items()]))
|
||||
get_llm_cache().update("foo", llm_string, [Generation(text="fizz")])
|
||||
output = llm.generate(["foo"])
|
||||
print(output) # noqa: T201
|
||||
expected_output = LLMResult(
|
||||
self.do_cache_test(llm, astradb_semantic_cache, "bar")
|
||||
output = llm.generate(["bar"]) # 'fizz' is erased away now
|
||||
assert output != LLMResult(
|
||||
generations=[[Generation(text="fizz")]],
|
||||
llm_output={},
|
||||
)
|
||||
print(expected_output) # noqa: T201
|
||||
assert output == expected_output
|
||||
astradb_cache.clear()
|
||||
astradb_semantic_cache.clear()
|
||||
|
||||
def test_cassandra_semantic_cache(
|
||||
self, astradb_semantic_cache: AstraDBSemanticCache
|
||||
async def test_astradb_semantic_cache_async(
|
||||
self, async_astradb_semantic_cache: AstraDBSemanticCache
|
||||
) -> None:
|
||||
set_llm_cache(astradb_semantic_cache)
|
||||
llm = FakeLLM()
|
||||
await self.ado_cache_test(llm, async_astradb_semantic_cache, "bar")
|
||||
output = await llm.agenerate(["bar"]) # 'fizz' is erased away now
|
||||
assert output != LLMResult(
|
||||
generations=[[Generation(text="fizz")]],
|
||||
llm_output={},
|
||||
)
|
||||
await async_astradb_semantic_cache.aclear()
|
||||
|
||||
@staticmethod
|
||||
def do_cache_test(llm: LLM, cache: BaseCache, prompt: str) -> None:
|
||||
set_llm_cache(cache)
|
||||
params = llm.dict()
|
||||
params["stop"] = None
|
||||
llm_string = str(sorted([(k, v) for k, v in params.items()]))
|
||||
get_llm_cache().update("foo", llm_string, [Generation(text="fizz")])
|
||||
output = llm.generate(["bar"]) # same embedding as 'foo'
|
||||
output = llm.generate([prompt])
|
||||
expected_output = LLMResult(
|
||||
generations=[[Generation(text="fizz")]],
|
||||
llm_output={},
|
||||
)
|
||||
assert output == expected_output
|
||||
# clear the cache
|
||||
astradb_semantic_cache.clear()
|
||||
output = llm.generate(["bar"]) # 'fizz' is erased away now
|
||||
assert output != expected_output
|
||||
astradb_semantic_cache.clear()
|
||||
cache.clear()
|
||||
|
||||
@staticmethod
|
||||
async def ado_cache_test(llm: LLM, cache: BaseCache, prompt: str) -> None:
|
||||
set_llm_cache(cache)
|
||||
params = llm.dict()
|
||||
params["stop"] = None
|
||||
llm_string = str(sorted([(k, v) for k, v in params.items()]))
|
||||
await get_llm_cache().aupdate("foo", llm_string, [Generation(text="fizz")])
|
||||
output = await llm.agenerate([prompt])
|
||||
expected_output = LLMResult(
|
||||
generations=[[Generation(text="fizz")]],
|
||||
llm_output={},
|
||||
)
|
||||
assert output == expected_output
|
||||
# clear the cache
|
||||
await cache.aclear()
|
||||
|
Loading…
Reference in New Issue
Block a user