diff --git a/libs/community/langchain_community/chat_message_histories/cassandra.py b/libs/community/langchain_community/chat_message_histories/cassandra.py index 3eb3d673ac0..4960f20a2b0 100644 --- a/libs/community/langchain_community/chat_message_histories/cassandra.py +++ b/libs/community/langchain_community/chat_message_histories/cassandra.py @@ -2,11 +2,10 @@ from __future__ import annotations import json -import typing import uuid -from typing import List +from typing import TYPE_CHECKING, List, Optional -if typing.TYPE_CHECKING: +if TYPE_CHECKING: from cassandra.cluster import Session from langchain_core.chat_history import BaseChatMessageHistory @@ -26,8 +25,8 @@ class CassandraChatMessageHistory(BaseChatMessageHistory): Args: session_id: arbitrary key that is used to store the messages of a single chat session. - session: a Cassandra `Session` object (an open DB connection) - keyspace: name of the keyspace to use. + session: Cassandra driver session. If not provided, it is resolved from cassio. + keyspace: Cassandra key space. If not provided, it is resolved from cassio. table_name: name of the table to use. ttl_seconds: time-to-live (seconds) for automatic expiration of stored entries. None (default) for no expiration. @@ -36,10 +35,10 @@ class CassandraChatMessageHistory(BaseChatMessageHistory): def __init__( self, session_id: str, - session: Session, - keyspace: str, + session: Optional[Session] = None, + keyspace: Optional[str] = None, table_name: str = DEFAULT_TABLE_NAME, - ttl_seconds: typing.Optional[int] = DEFAULT_TTL_SECONDS, + ttl_seconds: Optional[int] = DEFAULT_TTL_SECONDS, ) -> None: try: from cassio.table import ClusteredCassandraTable @@ -74,7 +73,11 @@ class CassandraChatMessageHistory(BaseChatMessageHistory): return messages def add_message(self, message: BaseMessage) -> None: - """Write a message to the table""" + """Write a message to the table + + Args: + message: A message to write. + """ this_row_id = uuid.uuid1() self.table.put( partition_id=self.session_id,