diff --git a/libs/community/langchain_community/document_loaders/astradb.py b/libs/community/langchain_community/document_loaders/astradb.py index a5c87aa97f4..b8f33b19662 100644 --- a/libs/community/langchain_community/document_loaders/astradb.py +++ b/libs/community/langchain_community/document_loaders/astradb.py @@ -2,12 +2,24 @@ import json import logging import threading from queue import Queue -from typing import Any, Callable, Dict, Iterator, List, Optional +from typing import ( + TYPE_CHECKING, + Any, + AsyncIterator, + Callable, + Dict, + Iterator, + List, + Optional, +) from langchain_core.documents import Document from langchain_community.document_loaders.base import BaseLoader +if TYPE_CHECKING: + from astrapy.db import AstraDB, AsyncAstraDB + logger = logging.getLogger(__name__) @@ -19,7 +31,8 @@ class AstraDBLoader(BaseLoader): collection_name: str, token: Optional[str] = None, api_endpoint: Optional[str] = None, - astra_db_client: Optional[Any] = None, # 'astrapy.db.AstraDB' if passed + astra_db_client: Optional["AstraDB"] = None, + async_astra_db_client: Optional["AsyncAstraDB"] = None, namespace: Optional[str] = None, filter_criteria: Optional[Dict[str, Any]] = None, projection: Optional[Dict[str, Any]] = None, @@ -36,34 +49,60 @@ class AstraDBLoader(BaseLoader): ) # Conflicting-arg checks: - if astra_db_client is not None: + if astra_db_client is not None or async_astra_db_client is not None: if token is not None or api_endpoint is not None: raise ValueError( - "You cannot pass 'astra_db_client' to AstraDB if passing " - "'token' and 'api_endpoint'." + "You cannot pass 'astra_db_client' or 'async_astra_db_client' to " + "AstraDB if passing 'token' and 'api_endpoint'." ) - + self.collection_name = collection_name self.filter = filter_criteria self.projection = projection self.find_options = find_options or {} self.nb_prefetched = nb_prefetched self.extraction_function = extraction_function - if astra_db_client is not None: - astra_db = astra_db_client - else: + astra_db = astra_db_client + async_astra_db = async_astra_db_client + + if token and api_endpoint: astra_db = AstraDB( token=token, api_endpoint=api_endpoint, namespace=namespace, ) - self.collection = astra_db.collection(collection_name) + try: + from astrapy.db import AsyncAstraDB + + async_astra_db = AsyncAstraDB( + token=token, + api_endpoint=api_endpoint, + namespace=namespace, + ) + except (ImportError, ModuleNotFoundError): + pass + if not astra_db and not async_astra_db: + raise ValueError( + "Must provide 'astra_db_client' or 'async_astra_db_client' or 'token' " + "and 'api_endpoint'" + ) + self.collection = astra_db.collection(collection_name) if astra_db else None + if async_astra_db: + from astrapy.db import AsyncAstraDBCollection + + self.async_collection = AsyncAstraDBCollection( + astra_db=async_astra_db, collection_name=collection_name + ) + else: + self.async_collection = None def load(self) -> List[Document]: """Eagerly load the content.""" return list(self.lazy_load()) def lazy_load(self) -> Iterator[Document]: + if not self.collection: + raise ValueError("Missing AstraDB client") queue = Queue(self.nb_prefetched) t = threading.Thread(target=self.fetch_results, args=(queue,)) t.start() @@ -74,6 +113,29 @@ class AstraDBLoader(BaseLoader): yield doc t.join() + async def aload(self) -> List[Document]: + """Load data into Document objects.""" + return [doc async for doc in self.alazy_load()] + + async def alazy_load(self) -> AsyncIterator[Document]: + if not self.async_collection: + raise ValueError("Missing AsyncAstraDB client") + async for doc in self.async_collection.paginated_find( + filter=self.filter, + options=self.find_options, + projection=self.projection, + sort=None, + prefetched=True, + ): + yield Document( + page_content=self.extraction_function(doc), + metadata={ + "namespace": self.async_collection.astra_db.namespace, + "api_endpoint": self.async_collection.astra_db.base_url, + "collection": self.collection_name, + }, + ) + def fetch_results(self, queue: Queue): self.fetch_page_result(queue) while self.find_options.get("pageState"): diff --git a/libs/community/tests/integration_tests/document_loaders/test_astradb.py b/libs/community/tests/integration_tests/document_loaders/test_astradb.py index 76489b26f4c..e6b00434280 100644 --- a/libs/community/tests/integration_tests/document_loaders/test_astradb.py +++ b/libs/community/tests/integration_tests/document_loaders/test_astradb.py @@ -13,11 +13,18 @@ Required to run this test: import json import os import uuid +from typing import TYPE_CHECKING import pytest from langchain_community.document_loaders.astradb import AstraDBLoader +if TYPE_CHECKING: + from astrapy.db import ( + AstraDBCollection, + AsyncAstraDBCollection, + ) + ASTRA_DB_APPLICATION_TOKEN = os.getenv("ASTRA_DB_APPLICATION_TOKEN") ASTRA_DB_API_ENDPOINT = os.getenv("ASTRA_DB_API_ENDPOINT") ASTRA_DB_KEYSPACE = os.getenv("ASTRA_DB_KEYSPACE") @@ -28,7 +35,7 @@ def _has_env_vars() -> bool: @pytest.fixture -def astra_db_collection(): +def astra_db_collection() -> "AstraDBCollection": from astrapy.db import AstraDB astra_db = AstraDB( @@ -38,21 +45,41 @@ def astra_db_collection(): ) collection_name = f"lc_test_loader_{str(uuid.uuid4()).split('-')[0]}" collection = astra_db.create_collection(collection_name) + collection.insert_many([{"foo": "bar", "baz": "qux"}] * 20) + collection.insert_many( + [{"foo": "bar2", "baz": "qux"}] * 4 + [{"foo": "bar", "baz": "qux"}] * 4 + ) yield collection astra_db.delete_collection(collection_name) +@pytest.fixture +async def async_astra_db_collection() -> "AsyncAstraDBCollection": + from astrapy.db import AsyncAstraDB + + astra_db = AsyncAstraDB( + token=ASTRA_DB_APPLICATION_TOKEN, + api_endpoint=ASTRA_DB_API_ENDPOINT, + namespace=ASTRA_DB_KEYSPACE, + ) + collection_name = f"lc_test_loader_{str(uuid.uuid4()).split('-')[0]}" + collection = await astra_db.create_collection(collection_name) + await collection.insert_many([{"foo": "bar", "baz": "qux"}] * 20) + await collection.insert_many( + [{"foo": "bar2", "baz": "qux"}] * 4 + [{"foo": "bar", "baz": "qux"}] * 4 + ) + + yield collection + + await astra_db.delete_collection(collection_name) + + @pytest.mark.requires("astrapy") @pytest.mark.skipif(not _has_env_vars(), reason="Missing Astra DB env. vars") class TestAstraDB: - def test_astradb_loader(self, astra_db_collection) -> None: - astra_db_collection.insert_many([{"foo": "bar", "baz": "qux"}] * 20) - astra_db_collection.insert_many( - [{"foo": "bar2", "baz": "qux"}] * 4 + [{"foo": "bar", "baz": "qux"}] * 4 - ) - + def test_astradb_loader(self, astra_db_collection: "AstraDBCollection") -> None: loader = AstraDBLoader( astra_db_collection.collection_name, token=ASTRA_DB_APPLICATION_TOKEN, @@ -79,9 +106,9 @@ class TestAstraDB: "collection": astra_db_collection.collection_name, } - def test_extraction_function(self, astra_db_collection) -> None: - astra_db_collection.insert_many([{"foo": "bar", "baz": "qux"}] * 20) - + def test_extraction_function( + self, astra_db_collection: "AstraDBCollection" + ) -> None: loader = AstraDBLoader( astra_db_collection.collection_name, token=ASTRA_DB_APPLICATION_TOKEN, @@ -94,3 +121,51 @@ class TestAstraDB: doc = next(docs) assert doc.page_content == "bar" + + async def test_astradb_loader_async( + self, async_astra_db_collection: "AsyncAstraDBCollection" + ) -> None: + await async_astra_db_collection.insert_many([{"foo": "bar", "baz": "qux"}] * 20) + await async_astra_db_collection.insert_many( + [{"foo": "bar2", "baz": "qux"}] * 4 + [{"foo": "bar", "baz": "qux"}] * 4 + ) + + loader = AstraDBLoader( + async_astra_db_collection.collection_name, + token=ASTRA_DB_APPLICATION_TOKEN, + api_endpoint=ASTRA_DB_API_ENDPOINT, + namespace=ASTRA_DB_KEYSPACE, + nb_prefetched=1, + projection={"foo": 1}, + find_options={"limit": 22}, + filter_criteria={"foo": "bar"}, + ) + docs = await loader.aload() + + assert len(docs) == 22 + ids = set() + for doc in docs: + content = json.loads(doc.page_content) + assert content["foo"] == "bar" + assert "baz" not in content + assert content["_id"] not in ids + ids.add(content["_id"]) + assert doc.metadata == { + "namespace": async_astra_db_collection.astra_db.namespace, + "api_endpoint": async_astra_db_collection.astra_db.base_url, + "collection": async_astra_db_collection.collection_name, + } + + async def test_extraction_function_async( + self, async_astra_db_collection: "AsyncAstraDBCollection" + ) -> None: + loader = AstraDBLoader( + async_astra_db_collection.collection_name, + token=ASTRA_DB_APPLICATION_TOKEN, + api_endpoint=ASTRA_DB_API_ENDPOINT, + namespace=ASTRA_DB_KEYSPACE, + find_options={"limit": 30}, + extraction_function=lambda x: x["foo"], + ) + doc = await anext(loader.alazy_load()) + assert doc.page_content == "bar"