diff --git a/libs/community/langchain_community/cache.py b/libs/community/langchain_community/cache.py index 0ad635092e8..b2001682942 100644 --- a/libs/community/langchain_community/cache.py +++ b/libs/community/langchain_community/cache.py @@ -60,12 +60,14 @@ from langchain_core.load.load import loads from langchain_core.outputs import ChatGeneration, Generation from langchain_core.utils import get_from_env +from langchain_community.utilities.astradb import AstraDBEnvironment from langchain_community.vectorstores.redis import Redis as RedisVectorstore logger = logging.getLogger(__file__) if TYPE_CHECKING: import momento + from astrapy.db import AstraDB from cassandra.cluster import Session as CassandraSession @@ -1262,7 +1264,7 @@ class AstraDBCache(BaseCache): collection_name: str = ASTRA_DB_CACHE_DEFAULT_COLLECTION_NAME, token: Optional[str] = None, api_endpoint: Optional[str] = None, - astra_db_client: Optional[Any] = None, # 'astrapy.db.AstraDB' if passed + astra_db_client: Optional[AstraDB] = None, namespace: Optional[str] = None, ): """ @@ -1278,39 +1280,17 @@ class AstraDBCache(BaseCache): namespace (Optional[str]): namespace (aka keyspace) where the collection is created. Defaults to the database's "default namespace". """ - try: - from astrapy.db import ( - AstraDB as LibAstraDB, - ) - except (ImportError, ModuleNotFoundError): - raise ImportError( - "Could not import a recent astrapy python package. " - "Please install it with `pip install --upgrade astrapy`." - ) - # Conflicting-arg checks: - if astra_db_client is not None: - if token is not None or api_endpoint is not None: - raise ValueError( - "You cannot pass 'astra_db_client' to AstraDB if passing " - "'token' and 'api_endpoint'." - ) - - self.collection_name = collection_name - self.token = token - self.api_endpoint = api_endpoint - self.namespace = namespace - - if astra_db_client is not None: - self.astra_db = astra_db_client - else: - self.astra_db = LibAstraDB( - token=self.token, - api_endpoint=self.api_endpoint, - namespace=self.namespace, - ) - self.collection = self.astra_db.create_collection( - collection_name=self.collection_name, + astra_env = AstraDBEnvironment( + token=token, + api_endpoint=api_endpoint, + astra_db_client=astra_db_client, + namespace=namespace, ) + self.astra_db = astra_env.astra_db + self.collection = self.astra_db.create_collection( + collection_name=collection_name, + ) + self.collection_name = collection_name @staticmethod def _make_id(prompt: str, llm_string: str) -> str: @@ -1364,7 +1344,7 @@ class AstraDBCache(BaseCache): def delete(self, prompt: str, llm_string: str) -> None: """Evict from cache if there's an entry.""" doc_id = self._make_id(prompt, llm_string) - return self.collection.delete_one(doc_id) + self.collection.delete_one(doc_id) def clear(self, **kwargs: Any) -> None: """Clear cache. This is for all LLMs at once.""" @@ -1395,7 +1375,7 @@ class AstraDBSemanticCache(BaseCache): collection_name: str = ASTRA_DB_CACHE_DEFAULT_COLLECTION_NAME, token: Optional[str] = None, api_endpoint: Optional[str] = None, - astra_db_client: Optional[Any] = None, # 'astrapy.db.AstraDB' if passed + astra_db_client: Optional[AstraDB] = None, namespace: Optional[str] = None, embedding: Embeddings, metric: Optional[str] = None, @@ -1423,22 +1403,13 @@ class AstraDBSemanticCache(BaseCache): The default score threshold is tuned to the default metric. Tune it carefully yourself if switching to another distance metric. """ - try: - from astrapy.db import ( - AstraDB as LibAstraDB, - ) - except (ImportError, ModuleNotFoundError): - raise ImportError( - "Could not import a recent astrapy python package. " - "Please install it with `pip install --upgrade astrapy`." - ) - # Conflicting-arg checks: - if astra_db_client is not None: - if token is not None or api_endpoint is not None: - raise ValueError( - "You cannot pass 'astra_db_client' to AstraDB if passing " - "'token' and 'api_endpoint'." - ) + astra_env = AstraDBEnvironment( + token=token, + api_endpoint=api_endpoint, + astra_db_client=astra_db_client, + namespace=namespace, + ) + self.astra_db = astra_env.astra_db self.embedding = embedding self.metric = metric @@ -1457,18 +1428,7 @@ class AstraDBSemanticCache(BaseCache): self.embedding_dimension = self._get_embedding_dimension() self.collection_name = collection_name - self.token = token - self.api_endpoint = api_endpoint - self.namespace = namespace - if astra_db_client is not None: - self.astra_db = astra_db_client - else: - self.astra_db = LibAstraDB( - token=self.token, - api_endpoint=self.api_endpoint, - namespace=self.namespace, - ) self.collection = self.astra_db.create_collection( collection_name=self.collection_name, dimension=self.embedding_dimension, diff --git a/libs/community/langchain_community/chat_message_histories/astradb.py b/libs/community/langchain_community/chat_message_histories/astradb.py index 27e4dc5c936..7257476101a 100644 --- a/libs/community/langchain_community/chat_message_histories/astradb.py +++ b/libs/community/langchain_community/chat_message_histories/astradb.py @@ -3,11 +3,12 @@ from __future__ import annotations import json import time -import typing -from typing import List, Optional +from typing import TYPE_CHECKING, List, Optional -if typing.TYPE_CHECKING: - from astrapy.db import AstraDB as LibAstraDB +from langchain_community.utilities.astradb import AstraDBEnvironment + +if TYPE_CHECKING: + from astrapy.db import AstraDB from langchain_core.chat_history import BaseChatMessageHistory from langchain_core.messages import ( @@ -42,40 +43,22 @@ class AstraDBChatMessageHistory(BaseChatMessageHistory): collection_name: str = DEFAULT_COLLECTION_NAME, token: Optional[str] = None, api_endpoint: Optional[str] = None, - astra_db_client: Optional[LibAstraDB] = None, # type 'astrapy.db.AstraDB' + astra_db_client: Optional[AstraDB] = None, namespace: Optional[str] = None, ) -> None: """Create an Astra DB chat message history.""" - try: - from astrapy.db import AstraDB as LibAstraDB - except (ImportError, ModuleNotFoundError): - raise ImportError( - "Could not import a recent astrapy python package. " - "Please install it with `pip install --upgrade astrapy`." - ) + astra_env = AstraDBEnvironment( + token=token, + api_endpoint=api_endpoint, + astra_db_client=astra_db_client, + namespace=namespace, + ) + self.astra_db = astra_env.astra_db - # Conflicting-arg checks: - if astra_db_client is not None: - if token is not None or api_endpoint is not None: - raise ValueError( - "You cannot pass 'astra_db_client' to AstraDB if passing " - "'token' and 'api_endpoint'." - ) + self.collection = self.astra_db.create_collection(collection_name) self.session_id = session_id self.collection_name = collection_name - self.token = token - self.api_endpoint = api_endpoint - self.namespace = namespace - if astra_db_client is not None: - self.astra_db = astra_db_client - else: - self.astra_db = LibAstraDB( - token=self.token, - api_endpoint=self.api_endpoint, - namespace=self.namespace, - ) - self.collection = self.astra_db.create_collection(self.collection_name) @property def messages(self) -> List[BaseMessage]: # type: ignore diff --git a/libs/community/langchain_community/document_loaders/astradb.py b/libs/community/langchain_community/document_loaders/astradb.py index 3562d424892..2af1e77fde1 100644 --- a/libs/community/langchain_community/document_loaders/astradb.py +++ b/libs/community/langchain_community/document_loaders/astradb.py @@ -16,8 +16,10 @@ from typing import ( ) from langchain_core.documents import Document +from langchain_core.runnables import run_in_executor from langchain_community.document_loaders.base import BaseLoader +from langchain_community.utilities.astradb import AstraDBEnvironment if TYPE_CHECKING: from astrapy.db import AstraDB, AsyncAstraDB @@ -42,21 +44,15 @@ class AstraDBLoader(BaseLoader): nb_prefetched: int = 1000, extraction_function: Callable[[Dict], str] = json.dumps, ) -> None: - try: - from astrapy.db import AstraDB - except (ImportError, ModuleNotFoundError): - raise ImportError( - "Could not import a recent astrapy python package. " - "Please install it with `pip install --upgrade astrapy`." - ) - - # Conflicting-arg checks: - if astra_db_client is not None or async_astra_db_client is not None: - if token is not None or api_endpoint is not None: - raise ValueError( - "You cannot pass 'astra_db_client' or 'async_astra_db_client' to " - "AstraDB if passing 'token' and 'api_endpoint'." - ) + astra_env = AstraDBEnvironment( + token=token, + api_endpoint=api_endpoint, + astra_db_client=astra_db_client, + async_astra_db_client=async_astra_db_client, + namespace=namespace, + ) + self.astra_env = astra_env + self.collection = astra_env.astra_db.collection(collection_name) self.collection_name = collection_name self.filter = filter_criteria self.projection = projection @@ -64,47 +60,11 @@ class AstraDBLoader(BaseLoader): self.nb_prefetched = nb_prefetched self.extraction_function = extraction_function - astra_db = astra_db_client - async_astra_db = async_astra_db_client - - if token and api_endpoint: - astra_db = AstraDB( - token=token, - api_endpoint=api_endpoint, - namespace=namespace, - ) - try: - from astrapy.db import AsyncAstraDB - - async_astra_db = AsyncAstraDB( - token=token, - api_endpoint=api_endpoint, - namespace=namespace, - ) - except (ImportError, ModuleNotFoundError): - pass - if not astra_db and not async_astra_db: - raise ValueError( - "Must provide 'astra_db_client' or 'async_astra_db_client' or 'token' " - "and 'api_endpoint'" - ) - self.collection = astra_db.collection(collection_name) if astra_db else None - if async_astra_db: - from astrapy.db import AsyncAstraDBCollection - - self.async_collection = AsyncAstraDBCollection( - astra_db=async_astra_db, collection_name=collection_name - ) - else: - self.async_collection = None - def load(self) -> List[Document]: """Eagerly load the content.""" return list(self.lazy_load()) def lazy_load(self) -> Iterator[Document]: - if not self.collection: - raise ValueError("Missing AstraDB client") queue = Queue(self.nb_prefetched) t = threading.Thread(target=self.fetch_results, args=(queue,)) t.start() @@ -120,9 +80,27 @@ class AstraDBLoader(BaseLoader): return [doc async for doc in self.alazy_load()] async def alazy_load(self) -> AsyncIterator[Document]: - if not self.async_collection: - raise ValueError("Missing AsyncAstraDB client") - async for doc in self.async_collection.paginated_find( + if not self.astra_env.async_astra_db: + iterator = run_in_executor( + None, + self.collection.paginated_find, + filter=self.filter, + options=self.find_options, + projection=self.projection, + sort=None, + prefetched=True, + ) + done = object() + while True: + item = await run_in_executor(None, lambda it: next(it, done), iterator) + if item is done: + break + yield item + return + async_collection = await self.astra_env.async_astra_db.collection( + self.collection_name + ) + async for doc in async_collection.paginated_find( filter=self.filter, options=self.find_options, projection=self.projection, @@ -132,8 +110,8 @@ class AstraDBLoader(BaseLoader): yield Document( page_content=self.extraction_function(doc), metadata={ - "namespace": self.async_collection.astra_db.namespace, - "api_endpoint": self.async_collection.astra_db.base_url, + "namespace": async_collection.astra_db.namespace, + "api_endpoint": async_collection.astra_db.base_url, "collection": self.collection_name, }, ) diff --git a/libs/community/langchain_community/storage/astradb.py b/libs/community/langchain_community/storage/astradb.py index f38874fb835..79d5bebcd37 100644 --- a/libs/community/langchain_community/storage/astradb.py +++ b/libs/community/langchain_community/storage/astradb.py @@ -1,6 +1,9 @@ +from __future__ import annotations + import base64 from abc import ABC, abstractmethod from typing import ( + TYPE_CHECKING, Any, Generic, Iterator, @@ -13,6 +16,11 @@ from typing import ( from langchain_core.stores import BaseStore, ByteStore +from langchain_community.utilities.astradb import AstraDBEnvironment + +if TYPE_CHECKING: + from astrapy.db import AstraDB + V = TypeVar("V") @@ -22,31 +30,19 @@ class AstraDBBaseStore(Generic[V], BaseStore[str, V], ABC): collection_name: str, token: Optional[str] = None, api_endpoint: Optional[str] = None, - astra_db_client: Optional[Any] = None, # 'astrapy.db.AstraDB' if passed + astra_db_client: Optional[AstraDB] = None, namespace: Optional[str] = None, ) -> None: - try: - from astrapy.db import AstraDB, AstraDBCollection - except (ImportError, ModuleNotFoundError): - raise ImportError( - "Could not import a recent astrapy python package. " - "Please install it with `pip install --upgrade astrapy`." - ) - - # Conflicting-arg checks: - if astra_db_client is not None: - if token is not None or api_endpoint is not None: - raise ValueError( - "You cannot pass 'astra_db_client' to AstraDB if passing " - "'token' and 'api_endpoint'." - ) - - astra_db = astra_db_client or AstraDB( + astra_env = AstraDBEnvironment( token=token, api_endpoint=api_endpoint, + astra_db_client=astra_db_client, namespace=namespace, ) - self.collection = AstraDBCollection(collection_name, astra_db=astra_db) + self.astra_db = astra_env.astra_db + self.collection = self.astra_db.create_collection( + collection_name=collection_name, + ) @abstractmethod def decode_value(self, value: Any) -> Optional[V]: diff --git a/libs/community/langchain_community/utilities/astradb.py b/libs/community/langchain_community/utilities/astradb.py new file mode 100644 index 00000000000..3ad3d327497 --- /dev/null +++ b/libs/community/langchain_community/utilities/astradb.py @@ -0,0 +1,84 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING, Optional + +if TYPE_CHECKING: + from astrapy.db import ( + AstraDB, + AsyncAstraDB, + ) + + +class AstraDBEnvironment: + def __init__( + self, + token: Optional[str] = None, + api_endpoint: Optional[str] = None, + astra_db_client: Optional[AstraDB] = None, + async_astra_db_client: Optional[AsyncAstraDB] = None, + namespace: Optional[str] = None, + ) -> None: + self.token = token + self.api_endpoint = api_endpoint + astra_db = astra_db_client + self.async_astra_db = async_astra_db_client + self.namespace = namespace + + from astrapy import db + + try: + from astrapy.db import AstraDB + except (ImportError, ModuleNotFoundError): + raise ImportError( + "Could not import a recent astrapy python package. " + "Please install it with `pip install --upgrade astrapy`." + ) + + supports_async = hasattr(db, "AsyncAstraDB") + + # Conflicting-arg checks: + if astra_db_client is not None or async_astra_db_client is not None: + if token is not None or api_endpoint is not None: + raise ValueError( + "You cannot pass 'astra_db_client' or 'async_astra_db_client' to " + "AstraDBEnvironment if passing 'token' and 'api_endpoint'." + ) + + if token and api_endpoint: + astra_db = AstraDB( + token=self.token, + api_endpoint=self.api_endpoint, + namespace=self.namespace, + ) + if supports_async: + self.async_astra_db = db.AsyncAstraDB( + token=self.token, + api_endpoint=self.api_endpoint, + namespace=self.namespace, + ) + + if astra_db: + self.astra_db = astra_db + else: + if self.async_astra_db: + self.astra_db = AstraDB( + token=self.async_astra_db.token, + api_endpoint=self.async_astra_db.base_url, + api_path=self.async_astra_db.api_path, + api_version=self.async_astra_db.api_version, + namespace=self.async_astra_db.namespace, + ) + else: + raise ValueError( + "Must provide 'astra_db_client' or 'async_astra_db_client' or " + "'token' and 'api_endpoint'" + ) + + if not self.async_astra_db and self.astra_db and supports_async: + self.async_astra_db = db.AsyncAstraDB( + token=self.astra_db.token, + api_endpoint=self.astra_db.base_url, + api_path=self.astra_db.api_path, + api_version=self.astra_db.api_version, + namespace=self.astra_db.namespace, + )