diff --git a/libs/community/langchain_community/document_loaders/astradb.py b/libs/community/langchain_community/document_loaders/astradb.py new file mode 100644 index 00000000000..a5c87aa97f4 --- /dev/null +++ b/libs/community/langchain_community/document_loaders/astradb.py @@ -0,0 +1,101 @@ +import json +import logging +import threading +from queue import Queue +from typing import Any, Callable, Dict, Iterator, List, Optional + +from langchain_core.documents import Document + +from langchain_community.document_loaders.base import BaseLoader + +logger = logging.getLogger(__name__) + + +class AstraDBLoader(BaseLoader): + """Load DataStax Astra DB documents.""" + + def __init__( + self, + collection_name: str, + token: Optional[str] = None, + api_endpoint: Optional[str] = None, + astra_db_client: Optional[Any] = None, # 'astrapy.db.AstraDB' if passed + namespace: Optional[str] = None, + filter_criteria: Optional[Dict[str, Any]] = None, + projection: Optional[Dict[str, Any]] = None, + find_options: Optional[Dict[str, Any]] = None, + nb_prefetched: int = 1000, + extraction_function: Callable[[Dict], str] = json.dumps, + ) -> None: + try: + from astrapy.db import AstraDB + except (ImportError, ModuleNotFoundError): + raise ImportError( + "Could not import a recent astrapy python package. " + "Please install it with `pip install --upgrade astrapy`." + ) + + # Conflicting-arg checks: + if astra_db_client is not None: + if token is not None or api_endpoint is not None: + raise ValueError( + "You cannot pass 'astra_db_client' to AstraDB if passing " + "'token' and 'api_endpoint'." + ) + + self.filter = filter_criteria + self.projection = projection + self.find_options = find_options or {} + self.nb_prefetched = nb_prefetched + self.extraction_function = extraction_function + + if astra_db_client is not None: + astra_db = astra_db_client + else: + astra_db = AstraDB( + token=token, + api_endpoint=api_endpoint, + namespace=namespace, + ) + self.collection = astra_db.collection(collection_name) + + def load(self) -> List[Document]: + """Eagerly load the content.""" + return list(self.lazy_load()) + + def lazy_load(self) -> Iterator[Document]: + queue = Queue(self.nb_prefetched) + t = threading.Thread(target=self.fetch_results, args=(queue,)) + t.start() + while True: + doc = queue.get() + if doc is None: + break + yield doc + t.join() + + def fetch_results(self, queue: Queue): + self.fetch_page_result(queue) + while self.find_options.get("pageState"): + self.fetch_page_result(queue) + queue.put(None) + + def fetch_page_result(self, queue: Queue): + res = self.collection.find( + filter=self.filter, + options=self.find_options, + projection=self.projection, + sort=None, + ) + self.find_options["pageState"] = res["data"].get("nextPageState") + for doc in res["data"]["documents"]: + queue.put( + Document( + page_content=self.extraction_function(doc), + metadata={ + "namespace": self.collection.astra_db.namespace, + "api_endpoint": self.collection.astra_db.base_url, + "collection": self.collection.collection_name, + }, + ) + ) diff --git a/libs/community/tests/integration_tests/document_loaders/test_astradb.py b/libs/community/tests/integration_tests/document_loaders/test_astradb.py new file mode 100644 index 00000000000..76489b26f4c --- /dev/null +++ b/libs/community/tests/integration_tests/document_loaders/test_astradb.py @@ -0,0 +1,96 @@ +""" +Test of Astra DB document loader class `AstraDBLoader` + +Required to run this test: + - a recent `astrapy` Python package available + - an Astra DB instance; + - the two environment variables set: + export ASTRA_DB_API_ENDPOINT="https://-us-east1.apps.astra.datastax.com" + export ASTRA_DB_APPLICATION_TOKEN="AstraCS:........." + - optionally this as well (otherwise defaults are used): + export ASTRA_DB_KEYSPACE="my_keyspace" +""" +import json +import os +import uuid + +import pytest + +from langchain_community.document_loaders.astradb import AstraDBLoader + +ASTRA_DB_APPLICATION_TOKEN = os.getenv("ASTRA_DB_APPLICATION_TOKEN") +ASTRA_DB_API_ENDPOINT = os.getenv("ASTRA_DB_API_ENDPOINT") +ASTRA_DB_KEYSPACE = os.getenv("ASTRA_DB_KEYSPACE") + + +def _has_env_vars() -> bool: + return all([ASTRA_DB_APPLICATION_TOKEN, ASTRA_DB_API_ENDPOINT]) + + +@pytest.fixture +def astra_db_collection(): + from astrapy.db import AstraDB + + astra_db = AstraDB( + token=ASTRA_DB_APPLICATION_TOKEN, + api_endpoint=ASTRA_DB_API_ENDPOINT, + namespace=ASTRA_DB_KEYSPACE, + ) + collection_name = f"lc_test_loader_{str(uuid.uuid4()).split('-')[0]}" + collection = astra_db.create_collection(collection_name) + + yield collection + + astra_db.delete_collection(collection_name) + + +@pytest.mark.requires("astrapy") +@pytest.mark.skipif(not _has_env_vars(), reason="Missing Astra DB env. vars") +class TestAstraDB: + def test_astradb_loader(self, astra_db_collection) -> None: + astra_db_collection.insert_many([{"foo": "bar", "baz": "qux"}] * 20) + astra_db_collection.insert_many( + [{"foo": "bar2", "baz": "qux"}] * 4 + [{"foo": "bar", "baz": "qux"}] * 4 + ) + + loader = AstraDBLoader( + astra_db_collection.collection_name, + token=ASTRA_DB_APPLICATION_TOKEN, + api_endpoint=ASTRA_DB_API_ENDPOINT, + namespace=ASTRA_DB_KEYSPACE, + nb_prefetched=1, + projection={"foo": 1}, + find_options={"limit": 22}, + filter_criteria={"foo": "bar"}, + ) + docs = loader.load() + + assert len(docs) == 22 + ids = set() + for doc in docs: + content = json.loads(doc.page_content) + assert content["foo"] == "bar" + assert "baz" not in content + assert content["_id"] not in ids + ids.add(content["_id"]) + assert doc.metadata == { + "namespace": astra_db_collection.astra_db.namespace, + "api_endpoint": astra_db_collection.astra_db.base_url, + "collection": astra_db_collection.collection_name, + } + + def test_extraction_function(self, astra_db_collection) -> None: + astra_db_collection.insert_many([{"foo": "bar", "baz": "qux"}] * 20) + + loader = AstraDBLoader( + astra_db_collection.collection_name, + token=ASTRA_DB_APPLICATION_TOKEN, + api_endpoint=ASTRA_DB_API_ENDPOINT, + namespace=ASTRA_DB_KEYSPACE, + find_options={"limit": 30}, + extraction_function=lambda x: x["foo"], + ) + docs = loader.lazy_load() + doc = next(docs) + + assert doc.page_content == "bar"