From a7af32c274860ee9174830804301491973aaee0a Mon Sep 17 00:00:00 2001 From: Harrison Chase Date: Mon, 15 May 2023 23:43:09 -0700 Subject: [PATCH] Cassandra support for chat history (#4378) (#4764) # Cassandra support for chat history ### Description - Store chat messages in cassandra ### Dependency - cassandra-driver - Python Module ## Before submitting - Added Integration Test ## Who can review? @hwchase17 @agola11 # Your PR Title (What it does) Fixes # (issue) ## Before submitting ## Who can review? Community members can review the PR once tests pass. Tag maintainers/contributors who might be interested: Co-authored-by: Jinto Jose <129657162+jj701@users.noreply.github.com> --- .../cassandra_chat_message_history.ipynb | 91 +++++++++ langchain/memory/__init__.py | 4 + .../memory/chat_message_histories/__init__.py | 4 + .../chat_message_histories/cassandra.py | 186 ++++++++++++++++++ pyproject.toml | 2 + .../memory/test_cassandra.py | 42 ++++ 6 files changed, 329 insertions(+) create mode 100644 docs/modules/memory/examples/cassandra_chat_message_history.ipynb create mode 100644 langchain/memory/chat_message_histories/cassandra.py create mode 100644 tests/integration_tests/memory/test_cassandra.py diff --git a/docs/modules/memory/examples/cassandra_chat_message_history.ipynb b/docs/modules/memory/examples/cassandra_chat_message_history.ipynb new file mode 100644 index 00000000000..7765f67e310 --- /dev/null +++ b/docs/modules/memory/examples/cassandra_chat_message_history.ipynb @@ -0,0 +1,91 @@ +{ + "cells": [ + { + "attachments": {}, + "cell_type": "markdown", + "id": "91c6a7ef", + "metadata": {}, + "source": [ + "# Cassandra Chat Message History\n", + "\n", + "This notebook goes over how to use Cassandra to store chat message history.\n", + "\n", + "Cassandra is a distributed database that is well suited for storing large amounts of data. \n", + "\n", + "It is a good choice for storing chat message history because it is easy to scale and can handle a large number of writes.\n" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "id": "47a601d2", + "metadata": {}, + "outputs": [], + "source": [ + "# List of contact points to try connecting to Cassandra cluster.\n", + "contact_points = [\"cassandra\"]" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "id": "d15e3302", + "metadata": {}, + "outputs": [], + "source": [ + "from langchain.memory import CassandraChatMessageHistory\n", + "\n", + "message_history = CassandraChatMessageHistory(\n", + " contact_points=contact_points, session_id=\"test-session\"\n", + ")\n", + "\n", + "message_history.add_user_message(\"hi!\")\n", + "\n", + "message_history.add_ai_message(\"whats up?\")" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "id": "64fc465e", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "[HumanMessage(content='hi!', additional_kwargs={}, example=False),\n", + " AIMessage(content='whats up?', additional_kwargs={}, example=False)]" + ] + }, + "execution_count": 5, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "message_history.messages" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.11.3" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/langchain/memory/__init__.py b/langchain/memory/__init__.py index b5e9950cc6e..ee491e34bc7 100644 --- a/langchain/memory/__init__.py +++ b/langchain/memory/__init__.py @@ -3,6 +3,9 @@ from langchain.memory.buffer import ( ConversationStringBufferMemory, ) from langchain.memory.buffer_window import ConversationBufferWindowMemory +from langchain.memory.chat_message_histories.cassandra import ( + CassandraChatMessageHistory, +) from langchain.memory.chat_message_histories.cosmos_db import CosmosDBChatMessageHistory from langchain.memory.chat_message_histories.dynamodb import DynamoDBChatMessageHistory from langchain.memory.chat_message_histories.file import FileChatMessageHistory @@ -46,4 +49,5 @@ __all__ = [ "CosmosDBChatMessageHistory", "FileChatMessageHistory", "MongoDBChatMessageHistory", + "CassandraChatMessageHistory", ] diff --git a/langchain/memory/chat_message_histories/__init__.py b/langchain/memory/chat_message_histories/__init__.py index cb646aaaf0d..bdc809406b0 100644 --- a/langchain/memory/chat_message_histories/__init__.py +++ b/langchain/memory/chat_message_histories/__init__.py @@ -1,3 +1,6 @@ +from langchain.memory.chat_message_histories.cassandra import ( + CassandraChatMessageHistory, +) from langchain.memory.chat_message_histories.cosmos_db import CosmosDBChatMessageHistory from langchain.memory.chat_message_histories.dynamodb import DynamoDBChatMessageHistory from langchain.memory.chat_message_histories.file import FileChatMessageHistory @@ -18,4 +21,5 @@ __all__ = [ "CosmosDBChatMessageHistory", "FirestoreChatMessageHistory", "MongoDBChatMessageHistory", + "CassandraChatMessageHistory", ] diff --git a/langchain/memory/chat_message_histories/cassandra.py b/langchain/memory/chat_message_histories/cassandra.py new file mode 100644 index 00000000000..d424792a394 --- /dev/null +++ b/langchain/memory/chat_message_histories/cassandra.py @@ -0,0 +1,186 @@ +import json +import logging +from typing import List + +from langchain.schema import ( + AIMessage, + BaseChatMessageHistory, + BaseMessage, + HumanMessage, + _message_to_dict, + messages_from_dict, +) + +logger = logging.getLogger(__name__) + +DEFAULT_KEYSPACE_NAME = "chat_history" +DEFAULT_TABLE_NAME = "message_store" +DEFAULT_USERNAME = "cassandra" +DEFAULT_PASSWORD = "cassandra" +DEFAULT_PORT = 9042 + + +class CassandraChatMessageHistory(BaseChatMessageHistory): + """Chat message history that stores history in Cassandra. + Args: + contact_points: list of ips to connect to Cassandra cluster + session_id: arbitrary key that is used to store the messages + of a single chat session. + port: port to connect to Cassandra cluster + username: username to connect to Cassandra cluster + password: password to connect to Cassandra cluster + keyspace_name: name of the keyspace to use + table_name: name of the table to use + """ + + def __init__( + self, + contact_points: List[str], + session_id: str, + port: int = DEFAULT_PORT, + username: str = DEFAULT_USERNAME, + password: str = DEFAULT_PASSWORD, + keyspace_name: str = DEFAULT_KEYSPACE_NAME, + table_name: str = DEFAULT_TABLE_NAME, + ): + self.contact_points = contact_points + self.session_id = session_id + self.port = port + self.username = username + self.password = password + self.keyspace_name = keyspace_name + self.table_name = table_name + + try: + from cassandra import ( + AuthenticationFailed, + OperationTimedOut, + UnresolvableContactPoints, + ) + from cassandra.cluster import Cluster, PlainTextAuthProvider + except ImportError: + raise ValueError( + "Could not import cassandra-driver python package. " + "Please install it with `pip install cassandra-driver`." + ) + + self.cluster: Cluster = Cluster( + contact_points, + port=port, + auth_provider=PlainTextAuthProvider( + username=self.username, password=self.password + ), + ) + + try: + self.session = self.cluster.connect() + except ( + AuthenticationFailed, + UnresolvableContactPoints, + OperationTimedOut, + ) as error: + logger.error( + "Unable to establish connection with \ + cassandra chat message history database" + ) + raise error + + self._prepare_cassandra() + + def _prepare_cassandra(self) -> None: + """Create the keyspace and table if they don't exist yet""" + + from cassandra import OperationTimedOut, Unavailable + + try: + self.session.execute( + f"""CREATE KEYSPACE IF NOT EXISTS + {self.keyspace_name} WITH REPLICATION = + {{ 'class' : 'SimpleStrategy', 'replication_factor' : 1 }};""" + ) + except (OperationTimedOut, Unavailable) as error: + logger.error( + f"Unable to create cassandra \ + chat message history keyspace: {self.keyspace_name}." + ) + raise error + + self.session.set_keyspace(self.keyspace_name) + + try: + self.session.execute( + f"""CREATE TABLE IF NOT EXISTS + {self.table_name} (id UUID, session_id varchar, + history text, PRIMARY KEY ((session_id), id) );""" + ) + except (OperationTimedOut, Unavailable) as error: + logger.error( + f"Unable to create cassandra \ + chat message history table: {self.table_name}" + ) + raise error + + @property + def messages(self) -> List[BaseMessage]: # type: ignore + """Retrieve the messages from Cassandra""" + from cassandra import ReadFailure, ReadTimeout, Unavailable + + try: + rows = self.session.execute( + f"""SELECT * FROM {self.table_name} + WHERE session_id = '{self.session_id}' ;""" + ) + except (Unavailable, ReadTimeout, ReadFailure) as error: + logger.error("Unable to Retreive chat history messages from cassadra") + raise error + + if rows: + items = [json.loads(row.history) for row in rows] + else: + items = [] + + messages = messages_from_dict(items) + + return messages + + def add_user_message(self, message: str) -> None: + self.append(HumanMessage(content=message)) + + def add_ai_message(self, message: str) -> None: + self.append(AIMessage(content=message)) + + def append(self, message: BaseMessage) -> None: + """Append the message to the record in Cassandra""" + + import uuid + + from cassandra import Unavailable, WriteFailure, WriteTimeout + + try: + self.session.execute( + """INSERT INTO message_store + (id, session_id, history) VALUES (%s, %s, %s);""", + (uuid.uuid4(), self.session_id, json.dumps(_message_to_dict(message))), + ) + except (Unavailable, WriteTimeout, WriteFailure) as error: + logger.error("Unable to write chat history messages to cassandra") + raise error + + def clear(self) -> None: + """Clear session memory from Cassandra""" + + from cassandra import OperationTimedOut, Unavailable + + try: + self.session.execute( + f"DELETE FROM {self.table_name} WHERE session_id = '{self.session_id}';" + ) + except (Unavailable, OperationTimedOut) as error: + logger.error("Unable to clear chat history messages from cassandra") + raise error + + def __del__(self) -> None: + if self.session: + self.session.shutdown() + if self.cluster: + self.cluster.shutdown() diff --git a/pyproject.toml b/pyproject.toml index 92948f8fe8a..08aafc139d4 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -143,8 +143,10 @@ promptlayer = "^0.1.80" tair = "^1.3.3" wikipedia = "^1" pymongo = "^4.3.3" +cassandra-driver = "^3.27.0" arxiv = "^1.4" + [tool.poetry.group.lint.dependencies] ruff = "^0.0.249" types-toml = "^0.10.8.1" diff --git a/tests/integration_tests/memory/test_cassandra.py b/tests/integration_tests/memory/test_cassandra.py new file mode 100644 index 00000000000..fc3f7f684bc --- /dev/null +++ b/tests/integration_tests/memory/test_cassandra.py @@ -0,0 +1,42 @@ +import json +import os + +from langchain.memory import ConversationBufferMemory +from langchain.memory.chat_message_histories.cassandra import ( + CassandraChatMessageHistory, +) +from langchain.schema import _message_to_dict + +# Replace these with your cassandra contact points +contact_points = ( + os.environ["CONTACT_POINTS"].split(",") + if "CONTACT_POINTS" in os.environ + else ["cassandra"] +) + + +def test_memory_with_message_store() -> None: + """Test the memory with a message store.""" + # setup cassandra as a message store + message_history = CassandraChatMessageHistory( + contact_points=contact_points, session_id="test-session" + ) + memory = ConversationBufferMemory( + memory_key="baz", chat_memory=message_history, return_messages=True + ) + + # add some messages + memory.chat_memory.add_ai_message("This is me, the AI") + memory.chat_memory.add_user_message("This is me, the human") + + # get the message history from the memory store and turn it into a json + messages = memory.chat_memory.messages + messages_json = json.dumps([_message_to_dict(msg) for msg in messages]) + + assert "This is me, the AI" in messages_json + assert "This is me, the human" in messages_json + + # remove the record from Cassandra, so the next test run won't pick it up + memory.chat_memory.clear() + + assert memory.chat_memory.messages == []