diff --git a/docs/docs/integrations/memory/astradb_chat_message_history.ipynb b/docs/docs/integrations/memory/astradb_chat_message_history.ipynb new file mode 100644 index 00000000000..52159c44f9b --- /dev/null +++ b/docs/docs/integrations/memory/astradb_chat_message_history.ipynb @@ -0,0 +1,147 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "90cd3ded", + "metadata": {}, + "source": [ + "# Astra DB \n", + "\n", + "> DataStax [Astra DB](https://docs.datastax.com/en/astra/home/astra.html) is a serverless vector-capable database built on Cassandra and made conveniently available through an easy-to-use JSON API.\n", + "\n", + "This notebook goes over how to use Astra DB to store chat message history." + ] + }, + { + "cell_type": "markdown", + "id": "f507f58b-bf22-4a48-8daf-68d869bcd1ba", + "metadata": {}, + "source": [ + "## Setting up\n", + "\n", + "To run this notebook you need a running Astra DB. Get the connection secrets on your Astra dashboard:\n", + "\n", + "- the API Endpoint looks like `https://01234567-89ab-cdef-0123-456789abcdef-us-east1.apps.astra.datastax.com`;\n", + "- the Token looks like `AstraCS:6gBhNmsk135...`." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "d7092199", + "metadata": {}, + "outputs": [], + "source": [ + "!pip install --quiet \"astrapy>=0.6.2\"" + ] + }, + { + "cell_type": "markdown", + "id": "e3d97b65", + "metadata": {}, + "source": [ + "### Set up the database connection parameters and secrets" + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "id": "163d97f0", + "metadata": {}, + "outputs": [ + { + "name": "stdin", + "output_type": "stream", + "text": [ + "ASTRA_DB_API_ENDPOINT = https://01234567-89ab-cdef-0123-456789abcdef-us-east1.apps.astra.datastax.com\n", + "ASTRA_DB_APPLICATION_TOKEN = ········\n" + ] + } + ], + "source": [ + "import getpass\n", + "\n", + "ASTRA_DB_API_ENDPOINT = input(\"ASTRA_DB_API_ENDPOINT = \")\n", + "ASTRA_DB_APPLICATION_TOKEN = getpass.getpass(\"ASTRA_DB_APPLICATION_TOKEN = \")" + ] + }, + { + "cell_type": "markdown", + "id": "55860b2d", + "metadata": {}, + "source": [ + "Depending on whether local or cloud-based Astra DB, create the corresponding database connection \"Session\" object." + ] + }, + { + "cell_type": "markdown", + "id": "36c163e8", + "metadata": {}, + "source": [ + "## Example" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "id": "d15e3302", + "metadata": {}, + "outputs": [], + "source": [ + "from langchain.memory import AstraDBChatMessageHistory\n", + "\n", + "message_history = AstraDBChatMessageHistory(\n", + " session_id=\"test-session\",\n", + " api_endpoint=ASTRA_DB_API_ENDPOINT,\n", + " token=ASTRA_DB_APPLICATION_TOKEN,\n", + ")\n", + "\n", + "message_history.add_user_message(\"hi!\")\n", + "\n", + "message_history.add_ai_message(\"whats up?\")" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "id": "64fc465e", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "[HumanMessage(content='hi!'), AIMessage(content='whats up?')]" + ] + }, + "execution_count": 3, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "message_history.messages" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.12" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/docs/docs/integrations/providers/astradb.mdx b/docs/docs/integrations/providers/astradb.mdx index 76b6864dd78..6fcb1fc6050 100644 --- a/docs/docs/integrations/providers/astradb.mdx +++ b/docs/docs/integrations/providers/astradb.mdx @@ -30,6 +30,20 @@ vector_store = AstraDB( Learn more in the [example notebook](/docs/integrations/vectorstores/astradb). +### Memory + +```python +from langchain.memory import AstraDBChatMessageHistory +message_history = AstraDBChatMessageHistory( + session_id="test-session" + api_endpoint="...", + token="...", +) +``` + +Learn more in the [example notebook](/docs/integrations/memory/astradb_chat_message_history). + + ## Apache Cassandra and Astra DB through CQL > [Cassandra](https://cassandra.apache.org/) is a NoSQL, row-oriented, highly scalable and highly available database. diff --git a/libs/langchain/langchain/memory/__init__.py b/libs/langchain/langchain/memory/__init__.py index a7049e98b93..c10bbf492c9 100644 --- a/libs/langchain/langchain/memory/__init__.py +++ b/libs/langchain/langchain/memory/__init__.py @@ -32,6 +32,7 @@ from langchain.memory.buffer import ( ) from langchain.memory.buffer_window import ConversationBufferWindowMemory from langchain.memory.chat_message_histories import ( + AstraDBChatMessageHistory, CassandraChatMessageHistory, ChatMessageHistory, CosmosDBChatMessageHistory, @@ -68,6 +69,7 @@ from langchain.memory.vectorstore import VectorStoreRetrieverMemory from langchain.memory.zep_memory import ZepMemory __all__ = [ + "AstraDBChatMessageHistory", "CassandraChatMessageHistory", "ChatMessageHistory", "CombinedMemory", diff --git a/libs/langchain/langchain/memory/chat_message_histories/__init__.py b/libs/langchain/langchain/memory/chat_message_histories/__init__.py index a1497e8a122..83fc7fa519a 100644 --- a/libs/langchain/langchain/memory/chat_message_histories/__init__.py +++ b/libs/langchain/langchain/memory/chat_message_histories/__init__.py @@ -1,3 +1,6 @@ +from langchain.memory.chat_message_histories.astradb import ( + AstraDBChatMessageHistory, +) from langchain.memory.chat_message_histories.cassandra import ( CassandraChatMessageHistory, ) @@ -31,6 +34,7 @@ from langchain.memory.chat_message_histories.xata import XataChatMessageHistory from langchain.memory.chat_message_histories.zep import ZepChatMessageHistory __all__ = [ + "AstraDBChatMessageHistory", "ChatMessageHistory", "CassandraChatMessageHistory", "CosmosDBChatMessageHistory", diff --git a/libs/langchain/langchain/memory/chat_message_histories/astradb.py b/libs/langchain/langchain/memory/chat_message_histories/astradb.py new file mode 100644 index 00000000000..27e4dc5c936 --- /dev/null +++ b/libs/langchain/langchain/memory/chat_message_histories/astradb.py @@ -0,0 +1,114 @@ +"""Astra DB - based chat message history, based on astrapy.""" +from __future__ import annotations + +import json +import time +import typing +from typing import List, Optional + +if typing.TYPE_CHECKING: + from astrapy.db import AstraDB as LibAstraDB + +from langchain_core.chat_history import BaseChatMessageHistory +from langchain_core.messages import ( + BaseMessage, + message_to_dict, + messages_from_dict, +) + +DEFAULT_COLLECTION_NAME = "langchain_message_store" + + +class AstraDBChatMessageHistory(BaseChatMessageHistory): + """Chat message history that stores history in Astra DB. + + Args (only keyword-arguments accepted): + session_id: arbitrary key that is used to store the messages + of a single chat session. + collection_name (str): name of the Astra DB collection to create/use. + token (Optional[str]): API token for Astra DB usage. + api_endpoint (Optional[str]): full URL to the API endpoint, + such as "https://-us-east1.apps.astra.datastax.com". + astra_db_client (Optional[Any]): *alternative to token+api_endpoint*, + you can pass an already-created 'astrapy.db.AstraDB' instance. + namespace (Optional[str]): namespace (aka keyspace) where the + collection is created. Defaults to the database's "default namespace". + """ + + def __init__( + self, + *, + session_id: str, + collection_name: str = DEFAULT_COLLECTION_NAME, + token: Optional[str] = None, + api_endpoint: Optional[str] = None, + astra_db_client: Optional[LibAstraDB] = None, # type 'astrapy.db.AstraDB' + namespace: Optional[str] = None, + ) -> None: + """Create an Astra DB chat message history.""" + try: + from astrapy.db import AstraDB as LibAstraDB + except (ImportError, ModuleNotFoundError): + raise ImportError( + "Could not import a recent astrapy python package. " + "Please install it with `pip install --upgrade astrapy`." + ) + + # Conflicting-arg checks: + if astra_db_client is not None: + if token is not None or api_endpoint is not None: + raise ValueError( + "You cannot pass 'astra_db_client' to AstraDB if passing " + "'token' and 'api_endpoint'." + ) + + self.session_id = session_id + self.collection_name = collection_name + self.token = token + self.api_endpoint = api_endpoint + self.namespace = namespace + if astra_db_client is not None: + self.astra_db = astra_db_client + else: + self.astra_db = LibAstraDB( + token=self.token, + api_endpoint=self.api_endpoint, + namespace=self.namespace, + ) + self.collection = self.astra_db.create_collection(self.collection_name) + + @property + def messages(self) -> List[BaseMessage]: # type: ignore + """Retrieve all session messages from DB""" + message_blobs = [ + doc["body_blob"] + for doc in sorted( + self.collection.paginated_find( + filter={ + "session_id": self.session_id, + }, + projection={ + "timestamp": 1, + "body_blob": 1, + }, + ), + key=lambda _doc: _doc["timestamp"], + ) + ] + items = [json.loads(message_blob) for message_blob in message_blobs] + messages = messages_from_dict(items) + return messages + + def add_message(self, message: BaseMessage) -> None: + """Write a message to the table""" + self.collection.insert_one( + { + "timestamp": time.time(), + "session_id": self.session_id, + "body_blob": json.dumps(message_to_dict(message)), + } + ) + + def clear(self) -> None: + """Clear session memory from DB""" + self.collection.delete_many(filter={"session_id": self.session_id}) diff --git a/libs/langchain/tests/integration_tests/memory/test_astradb.py b/libs/langchain/tests/integration_tests/memory/test_astradb.py new file mode 100644 index 00000000000..c1753807b55 --- /dev/null +++ b/libs/langchain/tests/integration_tests/memory/test_astradb.py @@ -0,0 +1,104 @@ +import os +from typing import Iterable + +import pytest +from langchain_core.messages import AIMessage, HumanMessage + +from langchain.memory import ConversationBufferMemory +from langchain.memory.chat_message_histories.astradb import ( + AstraDBChatMessageHistory, +) + + +def _has_env_vars() -> bool: + return all( + [ + "ASTRA_DB_APPLICATION_TOKEN" in os.environ, + "ASTRA_DB_API_ENDPOINT" in os.environ, + ] + ) + + +@pytest.fixture(scope="function") +def history1() -> Iterable[AstraDBChatMessageHistory]: + history1 = AstraDBChatMessageHistory( + session_id="session-test-1", + collection_name="langchain_cmh_test", + token=os.environ["ASTRA_DB_APPLICATION_TOKEN"], + api_endpoint=os.environ["ASTRA_DB_API_ENDPOINT"], + namespace=os.environ.get("ASTRA_DB_KEYSPACE"), + ) + yield history1 + history1.astra_db.delete_collection("langchain_cmh_test") + + +@pytest.fixture(scope="function") +def history2() -> Iterable[AstraDBChatMessageHistory]: + history2 = AstraDBChatMessageHistory( + session_id="session-test-2", + collection_name="langchain_cmh_test", + token=os.environ["ASTRA_DB_APPLICATION_TOKEN"], + api_endpoint=os.environ["ASTRA_DB_API_ENDPOINT"], + namespace=os.environ.get("ASTRA_DB_KEYSPACE"), + ) + yield history2 + history2.astra_db.delete_collection("langchain_cmh_test") + + +@pytest.mark.requires("astrapy") +@pytest.mark.skipif(not _has_env_vars(), reason="Missing Astra DB env. vars") +def test_memory_with_message_store(history1: AstraDBChatMessageHistory) -> None: + """Test the memory with a message store.""" + memory = ConversationBufferMemory( + memory_key="baz", + chat_memory=history1, + return_messages=True, + ) + + assert memory.chat_memory.messages == [] + + # add some messages + memory.chat_memory.add_ai_message("This is me, the AI") + memory.chat_memory.add_user_message("This is me, the human") + + messages = memory.chat_memory.messages + expected = [ + AIMessage(content="This is me, the AI"), + HumanMessage(content="This is me, the human"), + ] + assert messages == expected + + # clear the store + memory.chat_memory.clear() + + assert memory.chat_memory.messages == [] + + +@pytest.mark.requires("astrapy") +@pytest.mark.skipif(not _has_env_vars(), reason="Missing Astra DB env. vars") +def test_memory_separate_session_ids( + history1: AstraDBChatMessageHistory, history2: AstraDBChatMessageHistory +) -> None: + """Test that separate session IDs do not share entries.""" + memory1 = ConversationBufferMemory( + memory_key="mk1", + chat_memory=history1, + return_messages=True, + ) + memory2 = ConversationBufferMemory( + memory_key="mk2", + chat_memory=history2, + return_messages=True, + ) + + memory1.chat_memory.add_ai_message("Just saying.") + + assert memory2.chat_memory.messages == [] + + memory2.chat_memory.clear() + + assert memory1.chat_memory.messages != [] + + memory1.chat_memory.clear() + + assert memory1.chat_memory.messages == [] diff --git a/libs/langchain/tests/unit_tests/memory/test_imports.py b/libs/langchain/tests/unit_tests/memory/test_imports.py index f7ed3a9d4e6..e81db420c64 100644 --- a/libs/langchain/tests/unit_tests/memory/test_imports.py +++ b/libs/langchain/tests/unit_tests/memory/test_imports.py @@ -1,6 +1,7 @@ from langchain.memory import __all__ EXPECTED_ALL = [ + "AstraDBChatMessageHistory", "CassandraChatMessageHistory", "ChatMessageHistory", "CombinedMemory",