diff --git a/libs/community/langchain_community/chat_message_histories/cassandra.py b/libs/community/langchain_community/chat_message_histories/cassandra.py index 4960f20a2b0..1017cbe3952 100644 --- a/libs/community/langchain_community/chat_message_histories/cassandra.py +++ b/libs/community/langchain_community/chat_message_histories/cassandra.py @@ -3,10 +3,13 @@ from __future__ import annotations import json import uuid -from typing import TYPE_CHECKING, List, Optional +from typing import TYPE_CHECKING, Any, Dict, Iterable, List, Optional, Sequence + +from langchain_community.utilities.cassandra import SetupMode if TYPE_CHECKING: from cassandra.cluster import Session + from cassio.table.table_types import RowType from langchain_core.chat_history import BaseChatMessageHistory from langchain_core.messages import ( @@ -19,19 +22,14 @@ DEFAULT_TABLE_NAME = "message_store" DEFAULT_TTL_SECONDS = None +def _rows_to_messages(rows: Iterable[RowType]) -> List[BaseMessage]: + message_blobs = [row["body_blob"] for row in rows][::-1] + items = [json.loads(message_blob) for message_blob in message_blobs] + messages = messages_from_dict(items) + return messages + + class CassandraChatMessageHistory(BaseChatMessageHistory): - """Chat message history that stores history in Cassandra. - - Args: - session_id: arbitrary key that is used to store the messages - of a single chat session. - session: Cassandra driver session. If not provided, it is resolved from cassio. - keyspace: Cassandra key space. If not provided, it is resolved from cassio. - table_name: name of the table to use. - ttl_seconds: time-to-live (seconds) for automatic expiration - of stored entries. None (default) for no expiration. - """ - def __init__( self, session_id: str, @@ -39,7 +37,22 @@ class CassandraChatMessageHistory(BaseChatMessageHistory): keyspace: Optional[str] = None, table_name: str = DEFAULT_TABLE_NAME, ttl_seconds: Optional[int] = DEFAULT_TTL_SECONDS, + *, + setup_mode: SetupMode = SetupMode.SYNC, ) -> None: + """Chat message history that stores history in Cassandra. + + Args: + session_id: arbitrary key that is used to store the messages + of a single chat session. + session: Cassandra driver session. + If not provided, it is resolved from cassio. + keyspace: Cassandra key space. If not provided, it is resolved from cassio. + table_name: name of the table to use. + ttl_seconds: time-to-live (seconds) for automatic expiration + of stored entries. None (default) for no expiration. + setup_mode: mode used to create the Cassandra table (SYNC, ASYNC or OFF). + """ try: from cassio.table import ClusteredCassandraTable except (ImportError, ModuleNotFoundError): @@ -49,6 +62,9 @@ class CassandraChatMessageHistory(BaseChatMessageHistory): ) self.session_id = session_id self.ttl_seconds = ttl_seconds + kwargs: Dict[str, Any] = {} + if setup_mode == SetupMode.ASYNC: + kwargs["async_setup"] = True self.table = ClusteredCassandraTable( session=session, keyspace=keyspace, @@ -56,21 +72,26 @@ class CassandraChatMessageHistory(BaseChatMessageHistory): ttl_seconds=ttl_seconds, primary_key_type=["TEXT", "TIMEUUID"], ordering_in_partition="DESC", + skip_provisioning=setup_mode == SetupMode.OFF, + **kwargs, ) @property def messages(self) -> List[BaseMessage]: # type: ignore """Retrieve all session messages from DB""" # The latest are returned, in chronological order - message_blobs = [ - row["body_blob"] - for row in self.table.get_partition( - partition_id=self.session_id, - ) - ][::-1] - items = [json.loads(message_blob) for message_blob in message_blobs] - messages = messages_from_dict(items) - return messages + rows = self.table.get_partition( + partition_id=self.session_id, + ) + return _rows_to_messages(rows) + + async def aget_messages(self) -> List[BaseMessage]: + """Retrieve all session messages from DB""" + # The latest are returned, in chronological order + rows = await self.table.aget_partition( + partition_id=self.session_id, + ) + return _rows_to_messages(rows) def add_message(self, message: BaseMessage) -> None: """Write a message to the table @@ -86,6 +107,20 @@ class CassandraChatMessageHistory(BaseChatMessageHistory): ttl_seconds=self.ttl_seconds, ) + async def aadd_messages(self, messages: Sequence[BaseMessage]) -> None: + for message in messages: + this_row_id = uuid.uuid1() + await self.table.aput( + partition_id=self.session_id, + row_id=this_row_id, + body_blob=json.dumps(message_to_dict(message)), + ttl_seconds=self.ttl_seconds, + ) + def clear(self) -> None: """Clear session memory from DB""" self.table.delete_partition(self.session_id) + + async def aclear(self) -> None: + """Clear session memory from DB""" + await self.table.adelete_partition(self.session_id) diff --git a/libs/community/tests/integration_tests/memory/test_memory_cassandra.py b/libs/community/tests/integration_tests/memory/test_memory_cassandra.py index 6ff03ba6e77..5e7fc0535b4 100644 --- a/libs/community/tests/integration_tests/memory/test_memory_cassandra.py +++ b/libs/community/tests/integration_tests/memory/test_memory_cassandra.py @@ -1,6 +1,6 @@ import os import time -from typing import Optional +from typing import Any, Optional from langchain.memory import ConversationBufferMemory from langchain_core.messages import AIMessage, HumanMessage @@ -37,13 +37,15 @@ def _chat_message_history( # drop table if required if drop: session.execute(f"DROP TABLE IF EXISTS {keyspace}.{table_name}") - # + + kwargs: Any = {} if ttl_seconds is None else {"ttl_seconds": ttl_seconds} + return CassandraChatMessageHistory( session_id=session_id, session=session, keyspace=keyspace, table_name=table_name, - **({} if ttl_seconds is None else {"ttl_seconds": ttl_seconds}), + **kwargs, )