mirror of
https://github.com/hwchase17/langchain.git
synced 2025-07-04 12:18:24 +00:00
# Cassandra support for chat history ### Description - Store chat messages in cassandra ### Dependency - cassandra-driver - Python Module ## Before submitting - Added Integration Test ## Who can review? @hwchase17 @agola11 # Your PR Title (What it does) <!-- Thank you for contributing to LangChain! Your PR will appear in our next release under the title you set. Please make sure it highlights your valuable contribution. Replace this with a description of the change, the issue it fixes (if applicable), and relevant context. List any dependencies required for this change. After you're done, someone will review your PR. They may suggest improvements. If no one reviews your PR within a few days, feel free to @-mention the same people again, as notifications can get lost. --> <!-- Remove if not applicable --> Fixes # (issue) ## Before submitting <!-- If you're adding a new integration, include an integration test and an example notebook showing its use! --> ## Who can review? Community members can review the PR once tests pass. Tag maintainers/contributors who might be interested: <!-- For a quicker response, figure out the right person to tag with @ @hwchase17 - project lead Tracing / Callbacks - @agola11 Async - @agola11 DataLoaders - @eyurtsev Models - @hwchase17 - @agola11 Agents / Tools / Toolkits - @vowelparrot VectorStores / Retrievers / Memory - @dev2049 --> Co-authored-by: Jinto Jose <129657162+jj701@users.noreply.github.com>
This commit is contained in:
parent
c4c7936caa
commit
a7af32c274
@ -0,0 +1,91 @@
|
|||||||
|
{
|
||||||
|
"cells": [
|
||||||
|
{
|
||||||
|
"attachments": {},
|
||||||
|
"cell_type": "markdown",
|
||||||
|
"id": "91c6a7ef",
|
||||||
|
"metadata": {},
|
||||||
|
"source": [
|
||||||
|
"# Cassandra Chat Message History\n",
|
||||||
|
"\n",
|
||||||
|
"This notebook goes over how to use Cassandra to store chat message history.\n",
|
||||||
|
"\n",
|
||||||
|
"Cassandra is a distributed database that is well suited for storing large amounts of data. \n",
|
||||||
|
"\n",
|
||||||
|
"It is a good choice for storing chat message history because it is easy to scale and can handle a large number of writes.\n"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 3,
|
||||||
|
"id": "47a601d2",
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"# List of contact points to try connecting to Cassandra cluster.\n",
|
||||||
|
"contact_points = [\"cassandra\"]"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 4,
|
||||||
|
"id": "d15e3302",
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"from langchain.memory import CassandraChatMessageHistory\n",
|
||||||
|
"\n",
|
||||||
|
"message_history = CassandraChatMessageHistory(\n",
|
||||||
|
" contact_points=contact_points, session_id=\"test-session\"\n",
|
||||||
|
")\n",
|
||||||
|
"\n",
|
||||||
|
"message_history.add_user_message(\"hi!\")\n",
|
||||||
|
"\n",
|
||||||
|
"message_history.add_ai_message(\"whats up?\")"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 5,
|
||||||
|
"id": "64fc465e",
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [
|
||||||
|
{
|
||||||
|
"data": {
|
||||||
|
"text/plain": [
|
||||||
|
"[HumanMessage(content='hi!', additional_kwargs={}, example=False),\n",
|
||||||
|
" AIMessage(content='whats up?', additional_kwargs={}, example=False)]"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
"execution_count": 5,
|
||||||
|
"metadata": {},
|
||||||
|
"output_type": "execute_result"
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"source": [
|
||||||
|
"message_history.messages"
|
||||||
|
]
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"metadata": {
|
||||||
|
"kernelspec": {
|
||||||
|
"display_name": "Python 3 (ipykernel)",
|
||||||
|
"language": "python",
|
||||||
|
"name": "python3"
|
||||||
|
},
|
||||||
|
"language_info": {
|
||||||
|
"codemirror_mode": {
|
||||||
|
"name": "ipython",
|
||||||
|
"version": 3
|
||||||
|
},
|
||||||
|
"file_extension": ".py",
|
||||||
|
"mimetype": "text/x-python",
|
||||||
|
"name": "python",
|
||||||
|
"nbconvert_exporter": "python",
|
||||||
|
"pygments_lexer": "ipython3",
|
||||||
|
"version": "3.11.3"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"nbformat": 4,
|
||||||
|
"nbformat_minor": 5
|
||||||
|
}
|
@ -3,6 +3,9 @@ from langchain.memory.buffer import (
|
|||||||
ConversationStringBufferMemory,
|
ConversationStringBufferMemory,
|
||||||
)
|
)
|
||||||
from langchain.memory.buffer_window import ConversationBufferWindowMemory
|
from langchain.memory.buffer_window import ConversationBufferWindowMemory
|
||||||
|
from langchain.memory.chat_message_histories.cassandra import (
|
||||||
|
CassandraChatMessageHistory,
|
||||||
|
)
|
||||||
from langchain.memory.chat_message_histories.cosmos_db import CosmosDBChatMessageHistory
|
from langchain.memory.chat_message_histories.cosmos_db import CosmosDBChatMessageHistory
|
||||||
from langchain.memory.chat_message_histories.dynamodb import DynamoDBChatMessageHistory
|
from langchain.memory.chat_message_histories.dynamodb import DynamoDBChatMessageHistory
|
||||||
from langchain.memory.chat_message_histories.file import FileChatMessageHistory
|
from langchain.memory.chat_message_histories.file import FileChatMessageHistory
|
||||||
@ -46,4 +49,5 @@ __all__ = [
|
|||||||
"CosmosDBChatMessageHistory",
|
"CosmosDBChatMessageHistory",
|
||||||
"FileChatMessageHistory",
|
"FileChatMessageHistory",
|
||||||
"MongoDBChatMessageHistory",
|
"MongoDBChatMessageHistory",
|
||||||
|
"CassandraChatMessageHistory",
|
||||||
]
|
]
|
||||||
|
@ -1,3 +1,6 @@
|
|||||||
|
from langchain.memory.chat_message_histories.cassandra import (
|
||||||
|
CassandraChatMessageHistory,
|
||||||
|
)
|
||||||
from langchain.memory.chat_message_histories.cosmos_db import CosmosDBChatMessageHistory
|
from langchain.memory.chat_message_histories.cosmos_db import CosmosDBChatMessageHistory
|
||||||
from langchain.memory.chat_message_histories.dynamodb import DynamoDBChatMessageHistory
|
from langchain.memory.chat_message_histories.dynamodb import DynamoDBChatMessageHistory
|
||||||
from langchain.memory.chat_message_histories.file import FileChatMessageHistory
|
from langchain.memory.chat_message_histories.file import FileChatMessageHistory
|
||||||
@ -18,4 +21,5 @@ __all__ = [
|
|||||||
"CosmosDBChatMessageHistory",
|
"CosmosDBChatMessageHistory",
|
||||||
"FirestoreChatMessageHistory",
|
"FirestoreChatMessageHistory",
|
||||||
"MongoDBChatMessageHistory",
|
"MongoDBChatMessageHistory",
|
||||||
|
"CassandraChatMessageHistory",
|
||||||
]
|
]
|
||||||
|
186
langchain/memory/chat_message_histories/cassandra.py
Normal file
186
langchain/memory/chat_message_histories/cassandra.py
Normal file
@ -0,0 +1,186 @@
|
|||||||
|
import json
|
||||||
|
import logging
|
||||||
|
from typing import List
|
||||||
|
|
||||||
|
from langchain.schema import (
|
||||||
|
AIMessage,
|
||||||
|
BaseChatMessageHistory,
|
||||||
|
BaseMessage,
|
||||||
|
HumanMessage,
|
||||||
|
_message_to_dict,
|
||||||
|
messages_from_dict,
|
||||||
|
)
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
DEFAULT_KEYSPACE_NAME = "chat_history"
|
||||||
|
DEFAULT_TABLE_NAME = "message_store"
|
||||||
|
DEFAULT_USERNAME = "cassandra"
|
||||||
|
DEFAULT_PASSWORD = "cassandra"
|
||||||
|
DEFAULT_PORT = 9042
|
||||||
|
|
||||||
|
|
||||||
|
class CassandraChatMessageHistory(BaseChatMessageHistory):
|
||||||
|
"""Chat message history that stores history in Cassandra.
|
||||||
|
Args:
|
||||||
|
contact_points: list of ips to connect to Cassandra cluster
|
||||||
|
session_id: arbitrary key that is used to store the messages
|
||||||
|
of a single chat session.
|
||||||
|
port: port to connect to Cassandra cluster
|
||||||
|
username: username to connect to Cassandra cluster
|
||||||
|
password: password to connect to Cassandra cluster
|
||||||
|
keyspace_name: name of the keyspace to use
|
||||||
|
table_name: name of the table to use
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
contact_points: List[str],
|
||||||
|
session_id: str,
|
||||||
|
port: int = DEFAULT_PORT,
|
||||||
|
username: str = DEFAULT_USERNAME,
|
||||||
|
password: str = DEFAULT_PASSWORD,
|
||||||
|
keyspace_name: str = DEFAULT_KEYSPACE_NAME,
|
||||||
|
table_name: str = DEFAULT_TABLE_NAME,
|
||||||
|
):
|
||||||
|
self.contact_points = contact_points
|
||||||
|
self.session_id = session_id
|
||||||
|
self.port = port
|
||||||
|
self.username = username
|
||||||
|
self.password = password
|
||||||
|
self.keyspace_name = keyspace_name
|
||||||
|
self.table_name = table_name
|
||||||
|
|
||||||
|
try:
|
||||||
|
from cassandra import (
|
||||||
|
AuthenticationFailed,
|
||||||
|
OperationTimedOut,
|
||||||
|
UnresolvableContactPoints,
|
||||||
|
)
|
||||||
|
from cassandra.cluster import Cluster, PlainTextAuthProvider
|
||||||
|
except ImportError:
|
||||||
|
raise ValueError(
|
||||||
|
"Could not import cassandra-driver python package. "
|
||||||
|
"Please install it with `pip install cassandra-driver`."
|
||||||
|
)
|
||||||
|
|
||||||
|
self.cluster: Cluster = Cluster(
|
||||||
|
contact_points,
|
||||||
|
port=port,
|
||||||
|
auth_provider=PlainTextAuthProvider(
|
||||||
|
username=self.username, password=self.password
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
try:
|
||||||
|
self.session = self.cluster.connect()
|
||||||
|
except (
|
||||||
|
AuthenticationFailed,
|
||||||
|
UnresolvableContactPoints,
|
||||||
|
OperationTimedOut,
|
||||||
|
) as error:
|
||||||
|
logger.error(
|
||||||
|
"Unable to establish connection with \
|
||||||
|
cassandra chat message history database"
|
||||||
|
)
|
||||||
|
raise error
|
||||||
|
|
||||||
|
self._prepare_cassandra()
|
||||||
|
|
||||||
|
def _prepare_cassandra(self) -> None:
|
||||||
|
"""Create the keyspace and table if they don't exist yet"""
|
||||||
|
|
||||||
|
from cassandra import OperationTimedOut, Unavailable
|
||||||
|
|
||||||
|
try:
|
||||||
|
self.session.execute(
|
||||||
|
f"""CREATE KEYSPACE IF NOT EXISTS
|
||||||
|
{self.keyspace_name} WITH REPLICATION =
|
||||||
|
{{ 'class' : 'SimpleStrategy', 'replication_factor' : 1 }};"""
|
||||||
|
)
|
||||||
|
except (OperationTimedOut, Unavailable) as error:
|
||||||
|
logger.error(
|
||||||
|
f"Unable to create cassandra \
|
||||||
|
chat message history keyspace: {self.keyspace_name}."
|
||||||
|
)
|
||||||
|
raise error
|
||||||
|
|
||||||
|
self.session.set_keyspace(self.keyspace_name)
|
||||||
|
|
||||||
|
try:
|
||||||
|
self.session.execute(
|
||||||
|
f"""CREATE TABLE IF NOT EXISTS
|
||||||
|
{self.table_name} (id UUID, session_id varchar,
|
||||||
|
history text, PRIMARY KEY ((session_id), id) );"""
|
||||||
|
)
|
||||||
|
except (OperationTimedOut, Unavailable) as error:
|
||||||
|
logger.error(
|
||||||
|
f"Unable to create cassandra \
|
||||||
|
chat message history table: {self.table_name}"
|
||||||
|
)
|
||||||
|
raise error
|
||||||
|
|
||||||
|
@property
|
||||||
|
def messages(self) -> List[BaseMessage]: # type: ignore
|
||||||
|
"""Retrieve the messages from Cassandra"""
|
||||||
|
from cassandra import ReadFailure, ReadTimeout, Unavailable
|
||||||
|
|
||||||
|
try:
|
||||||
|
rows = self.session.execute(
|
||||||
|
f"""SELECT * FROM {self.table_name}
|
||||||
|
WHERE session_id = '{self.session_id}' ;"""
|
||||||
|
)
|
||||||
|
except (Unavailable, ReadTimeout, ReadFailure) as error:
|
||||||
|
logger.error("Unable to Retreive chat history messages from cassadra")
|
||||||
|
raise error
|
||||||
|
|
||||||
|
if rows:
|
||||||
|
items = [json.loads(row.history) for row in rows]
|
||||||
|
else:
|
||||||
|
items = []
|
||||||
|
|
||||||
|
messages = messages_from_dict(items)
|
||||||
|
|
||||||
|
return messages
|
||||||
|
|
||||||
|
def add_user_message(self, message: str) -> None:
|
||||||
|
self.append(HumanMessage(content=message))
|
||||||
|
|
||||||
|
def add_ai_message(self, message: str) -> None:
|
||||||
|
self.append(AIMessage(content=message))
|
||||||
|
|
||||||
|
def append(self, message: BaseMessage) -> None:
|
||||||
|
"""Append the message to the record in Cassandra"""
|
||||||
|
|
||||||
|
import uuid
|
||||||
|
|
||||||
|
from cassandra import Unavailable, WriteFailure, WriteTimeout
|
||||||
|
|
||||||
|
try:
|
||||||
|
self.session.execute(
|
||||||
|
"""INSERT INTO message_store
|
||||||
|
(id, session_id, history) VALUES (%s, %s, %s);""",
|
||||||
|
(uuid.uuid4(), self.session_id, json.dumps(_message_to_dict(message))),
|
||||||
|
)
|
||||||
|
except (Unavailable, WriteTimeout, WriteFailure) as error:
|
||||||
|
logger.error("Unable to write chat history messages to cassandra")
|
||||||
|
raise error
|
||||||
|
|
||||||
|
def clear(self) -> None:
|
||||||
|
"""Clear session memory from Cassandra"""
|
||||||
|
|
||||||
|
from cassandra import OperationTimedOut, Unavailable
|
||||||
|
|
||||||
|
try:
|
||||||
|
self.session.execute(
|
||||||
|
f"DELETE FROM {self.table_name} WHERE session_id = '{self.session_id}';"
|
||||||
|
)
|
||||||
|
except (Unavailable, OperationTimedOut) as error:
|
||||||
|
logger.error("Unable to clear chat history messages from cassandra")
|
||||||
|
raise error
|
||||||
|
|
||||||
|
def __del__(self) -> None:
|
||||||
|
if self.session:
|
||||||
|
self.session.shutdown()
|
||||||
|
if self.cluster:
|
||||||
|
self.cluster.shutdown()
|
@ -143,8 +143,10 @@ promptlayer = "^0.1.80"
|
|||||||
tair = "^1.3.3"
|
tair = "^1.3.3"
|
||||||
wikipedia = "^1"
|
wikipedia = "^1"
|
||||||
pymongo = "^4.3.3"
|
pymongo = "^4.3.3"
|
||||||
|
cassandra-driver = "^3.27.0"
|
||||||
arxiv = "^1.4"
|
arxiv = "^1.4"
|
||||||
|
|
||||||
|
|
||||||
[tool.poetry.group.lint.dependencies]
|
[tool.poetry.group.lint.dependencies]
|
||||||
ruff = "^0.0.249"
|
ruff = "^0.0.249"
|
||||||
types-toml = "^0.10.8.1"
|
types-toml = "^0.10.8.1"
|
||||||
|
42
tests/integration_tests/memory/test_cassandra.py
Normal file
42
tests/integration_tests/memory/test_cassandra.py
Normal file
@ -0,0 +1,42 @@
|
|||||||
|
import json
|
||||||
|
import os
|
||||||
|
|
||||||
|
from langchain.memory import ConversationBufferMemory
|
||||||
|
from langchain.memory.chat_message_histories.cassandra import (
|
||||||
|
CassandraChatMessageHistory,
|
||||||
|
)
|
||||||
|
from langchain.schema import _message_to_dict
|
||||||
|
|
||||||
|
# Replace these with your cassandra contact points
|
||||||
|
contact_points = (
|
||||||
|
os.environ["CONTACT_POINTS"].split(",")
|
||||||
|
if "CONTACT_POINTS" in os.environ
|
||||||
|
else ["cassandra"]
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def test_memory_with_message_store() -> None:
|
||||||
|
"""Test the memory with a message store."""
|
||||||
|
# setup cassandra as a message store
|
||||||
|
message_history = CassandraChatMessageHistory(
|
||||||
|
contact_points=contact_points, session_id="test-session"
|
||||||
|
)
|
||||||
|
memory = ConversationBufferMemory(
|
||||||
|
memory_key="baz", chat_memory=message_history, return_messages=True
|
||||||
|
)
|
||||||
|
|
||||||
|
# add some messages
|
||||||
|
memory.chat_memory.add_ai_message("This is me, the AI")
|
||||||
|
memory.chat_memory.add_user_message("This is me, the human")
|
||||||
|
|
||||||
|
# get the message history from the memory store and turn it into a json
|
||||||
|
messages = memory.chat_memory.messages
|
||||||
|
messages_json = json.dumps([_message_to_dict(msg) for msg in messages])
|
||||||
|
|
||||||
|
assert "This is me, the AI" in messages_json
|
||||||
|
assert "This is me, the human" in messages_json
|
||||||
|
|
||||||
|
# remove the record from Cassandra, so the next test run won't pick it up
|
||||||
|
memory.chat_memory.clear()
|
||||||
|
|
||||||
|
assert memory.chat_memory.messages == []
|
Loading…
Reference in New Issue
Block a user