diff --git a/libs/community/langchain_community/document_loaders/cassandra.py b/libs/community/langchain_community/document_loaders/cassandra.py new file mode 100644 index 00000000000..8983f3c56b6 --- /dev/null +++ b/libs/community/langchain_community/document_loaders/cassandra.py @@ -0,0 +1,123 @@ +import json +from typing import ( + TYPE_CHECKING, + Any, + Callable, + Iterator, + List, + Optional, + Sequence, + Union, +) + +from langchain_core.documents import Document + +from langchain_community.document_loaders.base import BaseLoader + + +def default_page_content_mapper(row: Any) -> str: + if hasattr(row, "_asdict"): + return json.dumps(row._asdict()) + return json.dumps(row) + + +_NOT_SET = object() + +if TYPE_CHECKING: + from cassandra.cluster import Session + from cassandra.pool import Host + from cassandra.query import Statement + + +class CassandraLoader(BaseLoader): + def __init__( + self, + table: Optional[str] = None, + session: Optional["Session"] = None, + keyspace: Optional[str] = None, + query: Optional[Union[str, "Statement"]] = None, + page_content_mapper: Callable[[Any], str] = default_page_content_mapper, + metadata_mapper: Callable[[Any], dict] = lambda _: {}, + *, + query_parameters: Union[dict, Sequence] = None, + query_timeout: Optional[float] = _NOT_SET, + query_trace: bool = False, + query_custom_payload: dict = None, + query_execution_profile: Any = _NOT_SET, + query_paging_state: Any = None, + query_host: "Host" = None, + query_execute_as: str = None, + ) -> None: + """ + Document Loader for Apache Cassandra. + + Args: + table: The table to load the data from. + (do not use together with the query parameter) + session: The cassandra driver session. + If not provided, the cassio resolved session will be used. + keyspace: The keyspace of the table. + If not provided, the cassio resolved keyspace will be used. + query: The query used to load the data. + (do not use together with the table parameter) + page_content_mapper: a function to convert a row to string page content. + query_parameters: The query parameters used when calling session.execute . + query_timeout: The query timeout used when calling session.execute . + query_custom_payload: The query custom_payload used when calling + session.execute . + query_execution_profile: The query execution_profile used when calling + session.execute . + query_host: The query host used when calling session.execute . + query_execute_as: The query execute_as used when calling session.execute . + """ + if query and table: + raise ValueError("Cannot specify both query and table.") + + if not query and not table: + raise ValueError("Must specify query or table.") + + if not session or (table and not keyspace): + try: + from cassio.config import check_resolve_keyspace, check_resolve_session + except (ImportError, ModuleNotFoundError): + raise ImportError( + "Could not import a recent cassio package." + "Please install it with `pip install --upgrade cassio`." + ) + + if table: + _keyspace = keyspace or check_resolve_keyspace(keyspace) + self.query = f"SELECT * FROM {_keyspace}.{table};" + self.metadata = {"table": table, "keyspace": _keyspace} + else: + self.query = query + self.metadata = {} + + self.session = session or check_resolve_session(session) + self.page_content_mapper = page_content_mapper + self.metadata_mapper = metadata_mapper + + self.query_kwargs = { + "parameters": query_parameters, + "trace": query_trace, + "custom_payload": query_custom_payload, + "paging_state": query_paging_state, + "host": query_host, + "execute_as": query_execute_as, + } + if query_timeout is not _NOT_SET: + self.query_kwargs["timeout"] = query_timeout + + if query_execution_profile is not _NOT_SET: + self.query_kwargs["execution_profile"] = query_execution_profile + + def load(self) -> List[Document]: + return list(self.lazy_load()) + + def lazy_load(self) -> Iterator[Document]: + for row in self.session.execute(self.query, **self.query_kwargs): + metadata = self.metadata.copy() + metadata.update(self.metadata_mapper(row)) + yield Document( + page_content=self.page_content_mapper(row), metadata=metadata + ) diff --git a/libs/community/tests/integration_tests/.env.example b/libs/community/tests/integration_tests/.env.example index 4ce3040f343..99be8383533 100644 --- a/libs/community/tests/integration_tests/.env.example +++ b/libs/community/tests/integration_tests/.env.example @@ -14,6 +14,13 @@ ASTRA_DB_APPLICATION_TOKEN=AstraCS:your_astra_db_application_token # ASTRA_DB_KEYSPACE=your_astra_db_namespace +# cassandra +CASSANDRA_CONTACT_POINTS=127.0.0.1 +# CASSANDRA_USERNAME=your_cassandra_username +# CASSANDRA_PASSWORD=your_cassandra_password +# CASSANDRA_KEYSPACE=your_cassandra_keyspace + + # pinecone # your api key from left menu "API Keys" in https://app.pinecone.io PINECONE_API_KEY=your_pinecone_api_key_here diff --git a/libs/community/tests/integration_tests/document_loaders/test_cassandra.py b/libs/community/tests/integration_tests/document_loaders/test_cassandra.py new file mode 100644 index 00000000000..59db9f7d3e8 --- /dev/null +++ b/libs/community/tests/integration_tests/document_loaders/test_cassandra.py @@ -0,0 +1,121 @@ +""" +Test of Cassandra document loader class `CassandraLoader` +""" +import os +from typing import Any + +import pytest +from langchain_core.documents import Document + +from langchain_community.document_loaders.cassandra import CassandraLoader + +CASSANDRA_DEFAULT_KEYSPACE = "docloader_test_keyspace" +CASSANDRA_TABLE = "docloader_test_table" + + +@pytest.fixture(autouse=True, scope="session") +def keyspace() -> str: + import cassio + from cassandra.cluster import Cluster + from cassio.config import check_resolve_session, resolve_keyspace + from cassio.table.tables import PlainCassandraTable + + if any( + env_var in os.environ + for env_var in [ + "CASSANDRA_CONTACT_POINTS", + "ASTRA_DB_APPLICATION_TOKEN", + "ASTRA_DB_INIT_STRING", + ] + ): + cassio.init(auto=True) + session = check_resolve_session() + else: + cluster = Cluster() + session = cluster.connect() + keyspace = resolve_keyspace() or CASSANDRA_DEFAULT_KEYSPACE + cassio.init(session=session, keyspace=keyspace) + + session.execute( + ( + f"CREATE KEYSPACE IF NOT EXISTS {keyspace} " + f"WITH replication = {{'class': 'SimpleStrategy', 'replication_factor': 1}}" + ) + ) + + # We use a cassio table by convenience to seed the DB + table = PlainCassandraTable( + table=CASSANDRA_TABLE, keyspace=keyspace, session=session + ) + table.put(row_id="id1", body_blob="text1") + table.put(row_id="id2", body_blob="text2") + + yield keyspace + + session.execute(f"DROP TABLE IF EXISTS {keyspace}.{CASSANDRA_TABLE}") + + +def test_loader_table(keyspace: str) -> None: + loader = CassandraLoader(table=CASSANDRA_TABLE) + assert loader.load() == [ + Document( + page_content='{"row_id": "id1", "body_blob": "text1"}', + metadata={"table": CASSANDRA_TABLE, "keyspace": keyspace}, + ), + Document( + page_content='{"row_id": "id2", "body_blob": "text2"}', + metadata={"table": CASSANDRA_TABLE, "keyspace": keyspace}, + ), + ] + + +def test_loader_query(keyspace: str) -> None: + loader = CassandraLoader( + query=f"SELECT body_blob FROM {keyspace}.{CASSANDRA_TABLE}" + ) + assert loader.load() == [ + Document(page_content='{"body_blob": "text1"}'), + Document(page_content='{"body_blob": "text2"}'), + ] + + +def test_loader_page_content_mapper(keyspace: str) -> None: + def mapper(row: Any) -> str: + return str(row.body_blob) + + loader = CassandraLoader(table=CASSANDRA_TABLE, page_content_mapper=mapper) + assert loader.load() == [ + Document( + page_content="text1", + metadata={"table": CASSANDRA_TABLE, "keyspace": keyspace}, + ), + Document( + page_content="text2", + metadata={"table": CASSANDRA_TABLE, "keyspace": keyspace}, + ), + ] + + +def test_loader_metadata_mapper(keyspace: str) -> None: + def mapper(row: Any) -> dict: + return {"id": row.row_id} + + loader = CassandraLoader(table=CASSANDRA_TABLE, metadata_mapper=mapper) + assert loader.load() == [ + Document( + page_content='{"row_id": "id1", "body_blob": "text1"}', + metadata={ + "table": CASSANDRA_TABLE, + "keyspace": keyspace, + "id": "id1", + }, + ), + Document( + page_content='{"row_id": "id2", "body_blob": "text2"}', + metadata={ + "table": CASSANDRA_TABLE, + "keyspace": keyspace, + "id": "id2", + }, + ), + ]