mirror of
https://github.com/hwchase17/langchain.git
synced 2025-06-25 16:13:25 +00:00
Astra DB, chat message history (#13836)
This PR adds a chat message history component that uses Astra DB for persistence through the JSON API. The `astrapy` package is required for this class to work. I have added tests and a small notebook, and updated the relevant references in the other docs pages. (@rlancemartin this is the counterpart of the Cassandra equivalent class you so helpfully reviewed back at the end of June) Thank you!
This commit is contained in:
parent
58f7e109ac
commit
272df9dcae
147
docs/docs/integrations/memory/astradb_chat_message_history.ipynb
Normal file
147
docs/docs/integrations/memory/astradb_chat_message_history.ipynb
Normal file
@ -0,0 +1,147 @@
|
|||||||
|
{
|
||||||
|
"cells": [
|
||||||
|
{
|
||||||
|
"cell_type": "markdown",
|
||||||
|
"id": "90cd3ded",
|
||||||
|
"metadata": {},
|
||||||
|
"source": [
|
||||||
|
"# Astra DB \n",
|
||||||
|
"\n",
|
||||||
|
"> DataStax [Astra DB](https://docs.datastax.com/en/astra/home/astra.html) is a serverless vector-capable database built on Cassandra and made conveniently available through an easy-to-use JSON API.\n",
|
||||||
|
"\n",
|
||||||
|
"This notebook goes over how to use Astra DB to store chat message history."
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "markdown",
|
||||||
|
"id": "f507f58b-bf22-4a48-8daf-68d869bcd1ba",
|
||||||
|
"metadata": {},
|
||||||
|
"source": [
|
||||||
|
"## Setting up\n",
|
||||||
|
"\n",
|
||||||
|
"To run this notebook you need a running Astra DB. Get the connection secrets on your Astra dashboard:\n",
|
||||||
|
"\n",
|
||||||
|
"- the API Endpoint looks like `https://01234567-89ab-cdef-0123-456789abcdef-us-east1.apps.astra.datastax.com`;\n",
|
||||||
|
"- the Token looks like `AstraCS:6gBhNmsk135...`."
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": null,
|
||||||
|
"id": "d7092199",
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"!pip install --quiet \"astrapy>=0.6.2\""
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "markdown",
|
||||||
|
"id": "e3d97b65",
|
||||||
|
"metadata": {},
|
||||||
|
"source": [
|
||||||
|
"### Set up the database connection parameters and secrets"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 1,
|
||||||
|
"id": "163d97f0",
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [
|
||||||
|
{
|
||||||
|
"name": "stdin",
|
||||||
|
"output_type": "stream",
|
||||||
|
"text": [
|
||||||
|
"ASTRA_DB_API_ENDPOINT = https://01234567-89ab-cdef-0123-456789abcdef-us-east1.apps.astra.datastax.com\n",
|
||||||
|
"ASTRA_DB_APPLICATION_TOKEN = ········\n"
|
||||||
|
]
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"source": [
|
||||||
|
"import getpass\n",
|
||||||
|
"\n",
|
||||||
|
"ASTRA_DB_API_ENDPOINT = input(\"ASTRA_DB_API_ENDPOINT = \")\n",
|
||||||
|
"ASTRA_DB_APPLICATION_TOKEN = getpass.getpass(\"ASTRA_DB_APPLICATION_TOKEN = \")"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "markdown",
|
||||||
|
"id": "55860b2d",
|
||||||
|
"metadata": {},
|
||||||
|
"source": [
|
||||||
|
"Depending on whether local or cloud-based Astra DB, create the corresponding database connection \"Session\" object."
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "markdown",
|
||||||
|
"id": "36c163e8",
|
||||||
|
"metadata": {},
|
||||||
|
"source": [
|
||||||
|
"## Example"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 2,
|
||||||
|
"id": "d15e3302",
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"from langchain.memory import AstraDBChatMessageHistory\n",
|
||||||
|
"\n",
|
||||||
|
"message_history = AstraDBChatMessageHistory(\n",
|
||||||
|
" session_id=\"test-session\",\n",
|
||||||
|
" api_endpoint=ASTRA_DB_API_ENDPOINT,\n",
|
||||||
|
" token=ASTRA_DB_APPLICATION_TOKEN,\n",
|
||||||
|
")\n",
|
||||||
|
"\n",
|
||||||
|
"message_history.add_user_message(\"hi!\")\n",
|
||||||
|
"\n",
|
||||||
|
"message_history.add_ai_message(\"whats up?\")"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 3,
|
||||||
|
"id": "64fc465e",
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [
|
||||||
|
{
|
||||||
|
"data": {
|
||||||
|
"text/plain": [
|
||||||
|
"[HumanMessage(content='hi!'), AIMessage(content='whats up?')]"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
"execution_count": 3,
|
||||||
|
"metadata": {},
|
||||||
|
"output_type": "execute_result"
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"source": [
|
||||||
|
"message_history.messages"
|
||||||
|
]
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"metadata": {
|
||||||
|
"kernelspec": {
|
||||||
|
"display_name": "Python 3 (ipykernel)",
|
||||||
|
"language": "python",
|
||||||
|
"name": "python3"
|
||||||
|
},
|
||||||
|
"language_info": {
|
||||||
|
"codemirror_mode": {
|
||||||
|
"name": "ipython",
|
||||||
|
"version": 3
|
||||||
|
},
|
||||||
|
"file_extension": ".py",
|
||||||
|
"mimetype": "text/x-python",
|
||||||
|
"name": "python",
|
||||||
|
"nbconvert_exporter": "python",
|
||||||
|
"pygments_lexer": "ipython3",
|
||||||
|
"version": "3.10.12"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"nbformat": 4,
|
||||||
|
"nbformat_minor": 5
|
||||||
|
}
|
@ -30,6 +30,20 @@ vector_store = AstraDB(
|
|||||||
Learn more in the [example notebook](/docs/integrations/vectorstores/astradb).
|
Learn more in the [example notebook](/docs/integrations/vectorstores/astradb).
|
||||||
|
|
||||||
|
|
||||||
|
### Memory
|
||||||
|
|
||||||
|
```python
|
||||||
|
from langchain.memory import AstraDBChatMessageHistory
|
||||||
|
message_history = AstraDBChatMessageHistory(
|
||||||
|
session_id="test-session"
|
||||||
|
api_endpoint="...",
|
||||||
|
token="...",
|
||||||
|
)
|
||||||
|
```
|
||||||
|
|
||||||
|
Learn more in the [example notebook](/docs/integrations/memory/astradb_chat_message_history).
|
||||||
|
|
||||||
|
|
||||||
## Apache Cassandra and Astra DB through CQL
|
## Apache Cassandra and Astra DB through CQL
|
||||||
|
|
||||||
> [Cassandra](https://cassandra.apache.org/) is a NoSQL, row-oriented, highly scalable and highly available database.
|
> [Cassandra](https://cassandra.apache.org/) is a NoSQL, row-oriented, highly scalable and highly available database.
|
||||||
|
@ -32,6 +32,7 @@ from langchain.memory.buffer import (
|
|||||||
)
|
)
|
||||||
from langchain.memory.buffer_window import ConversationBufferWindowMemory
|
from langchain.memory.buffer_window import ConversationBufferWindowMemory
|
||||||
from langchain.memory.chat_message_histories import (
|
from langchain.memory.chat_message_histories import (
|
||||||
|
AstraDBChatMessageHistory,
|
||||||
CassandraChatMessageHistory,
|
CassandraChatMessageHistory,
|
||||||
ChatMessageHistory,
|
ChatMessageHistory,
|
||||||
CosmosDBChatMessageHistory,
|
CosmosDBChatMessageHistory,
|
||||||
@ -68,6 +69,7 @@ from langchain.memory.vectorstore import VectorStoreRetrieverMemory
|
|||||||
from langchain.memory.zep_memory import ZepMemory
|
from langchain.memory.zep_memory import ZepMemory
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
|
"AstraDBChatMessageHistory",
|
||||||
"CassandraChatMessageHistory",
|
"CassandraChatMessageHistory",
|
||||||
"ChatMessageHistory",
|
"ChatMessageHistory",
|
||||||
"CombinedMemory",
|
"CombinedMemory",
|
||||||
|
@ -1,3 +1,6 @@
|
|||||||
|
from langchain.memory.chat_message_histories.astradb import (
|
||||||
|
AstraDBChatMessageHistory,
|
||||||
|
)
|
||||||
from langchain.memory.chat_message_histories.cassandra import (
|
from langchain.memory.chat_message_histories.cassandra import (
|
||||||
CassandraChatMessageHistory,
|
CassandraChatMessageHistory,
|
||||||
)
|
)
|
||||||
@ -31,6 +34,7 @@ from langchain.memory.chat_message_histories.xata import XataChatMessageHistory
|
|||||||
from langchain.memory.chat_message_histories.zep import ZepChatMessageHistory
|
from langchain.memory.chat_message_histories.zep import ZepChatMessageHistory
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
|
"AstraDBChatMessageHistory",
|
||||||
"ChatMessageHistory",
|
"ChatMessageHistory",
|
||||||
"CassandraChatMessageHistory",
|
"CassandraChatMessageHistory",
|
||||||
"CosmosDBChatMessageHistory",
|
"CosmosDBChatMessageHistory",
|
||||||
|
@ -0,0 +1,114 @@
|
|||||||
|
"""Astra DB - based chat message history, based on astrapy."""
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import json
|
||||||
|
import time
|
||||||
|
import typing
|
||||||
|
from typing import List, Optional
|
||||||
|
|
||||||
|
if typing.TYPE_CHECKING:
|
||||||
|
from astrapy.db import AstraDB as LibAstraDB
|
||||||
|
|
||||||
|
from langchain_core.chat_history import BaseChatMessageHistory
|
||||||
|
from langchain_core.messages import (
|
||||||
|
BaseMessage,
|
||||||
|
message_to_dict,
|
||||||
|
messages_from_dict,
|
||||||
|
)
|
||||||
|
|
||||||
|
DEFAULT_COLLECTION_NAME = "langchain_message_store"
|
||||||
|
|
||||||
|
|
||||||
|
class AstraDBChatMessageHistory(BaseChatMessageHistory):
|
||||||
|
"""Chat message history that stores history in Astra DB.
|
||||||
|
|
||||||
|
Args (only keyword-arguments accepted):
|
||||||
|
session_id: arbitrary key that is used to store the messages
|
||||||
|
of a single chat session.
|
||||||
|
collection_name (str): name of the Astra DB collection to create/use.
|
||||||
|
token (Optional[str]): API token for Astra DB usage.
|
||||||
|
api_endpoint (Optional[str]): full URL to the API endpoint,
|
||||||
|
such as "https://<DB-ID>-us-east1.apps.astra.datastax.com".
|
||||||
|
astra_db_client (Optional[Any]): *alternative to token+api_endpoint*,
|
||||||
|
you can pass an already-created 'astrapy.db.AstraDB' instance.
|
||||||
|
namespace (Optional[str]): namespace (aka keyspace) where the
|
||||||
|
collection is created. Defaults to the database's "default namespace".
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
*,
|
||||||
|
session_id: str,
|
||||||
|
collection_name: str = DEFAULT_COLLECTION_NAME,
|
||||||
|
token: Optional[str] = None,
|
||||||
|
api_endpoint: Optional[str] = None,
|
||||||
|
astra_db_client: Optional[LibAstraDB] = None, # type 'astrapy.db.AstraDB'
|
||||||
|
namespace: Optional[str] = None,
|
||||||
|
) -> None:
|
||||||
|
"""Create an Astra DB chat message history."""
|
||||||
|
try:
|
||||||
|
from astrapy.db import AstraDB as LibAstraDB
|
||||||
|
except (ImportError, ModuleNotFoundError):
|
||||||
|
raise ImportError(
|
||||||
|
"Could not import a recent astrapy python package. "
|
||||||
|
"Please install it with `pip install --upgrade astrapy`."
|
||||||
|
)
|
||||||
|
|
||||||
|
# Conflicting-arg checks:
|
||||||
|
if astra_db_client is not None:
|
||||||
|
if token is not None or api_endpoint is not None:
|
||||||
|
raise ValueError(
|
||||||
|
"You cannot pass 'astra_db_client' to AstraDB if passing "
|
||||||
|
"'token' and 'api_endpoint'."
|
||||||
|
)
|
||||||
|
|
||||||
|
self.session_id = session_id
|
||||||
|
self.collection_name = collection_name
|
||||||
|
self.token = token
|
||||||
|
self.api_endpoint = api_endpoint
|
||||||
|
self.namespace = namespace
|
||||||
|
if astra_db_client is not None:
|
||||||
|
self.astra_db = astra_db_client
|
||||||
|
else:
|
||||||
|
self.astra_db = LibAstraDB(
|
||||||
|
token=self.token,
|
||||||
|
api_endpoint=self.api_endpoint,
|
||||||
|
namespace=self.namespace,
|
||||||
|
)
|
||||||
|
self.collection = self.astra_db.create_collection(self.collection_name)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def messages(self) -> List[BaseMessage]: # type: ignore
|
||||||
|
"""Retrieve all session messages from DB"""
|
||||||
|
message_blobs = [
|
||||||
|
doc["body_blob"]
|
||||||
|
for doc in sorted(
|
||||||
|
self.collection.paginated_find(
|
||||||
|
filter={
|
||||||
|
"session_id": self.session_id,
|
||||||
|
},
|
||||||
|
projection={
|
||||||
|
"timestamp": 1,
|
||||||
|
"body_blob": 1,
|
||||||
|
},
|
||||||
|
),
|
||||||
|
key=lambda _doc: _doc["timestamp"],
|
||||||
|
)
|
||||||
|
]
|
||||||
|
items = [json.loads(message_blob) for message_blob in message_blobs]
|
||||||
|
messages = messages_from_dict(items)
|
||||||
|
return messages
|
||||||
|
|
||||||
|
def add_message(self, message: BaseMessage) -> None:
|
||||||
|
"""Write a message to the table"""
|
||||||
|
self.collection.insert_one(
|
||||||
|
{
|
||||||
|
"timestamp": time.time(),
|
||||||
|
"session_id": self.session_id,
|
||||||
|
"body_blob": json.dumps(message_to_dict(message)),
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
def clear(self) -> None:
|
||||||
|
"""Clear session memory from DB"""
|
||||||
|
self.collection.delete_many(filter={"session_id": self.session_id})
|
104
libs/langchain/tests/integration_tests/memory/test_astradb.py
Normal file
104
libs/langchain/tests/integration_tests/memory/test_astradb.py
Normal file
@ -0,0 +1,104 @@
|
|||||||
|
import os
|
||||||
|
from typing import Iterable
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
from langchain_core.messages import AIMessage, HumanMessage
|
||||||
|
|
||||||
|
from langchain.memory import ConversationBufferMemory
|
||||||
|
from langchain.memory.chat_message_histories.astradb import (
|
||||||
|
AstraDBChatMessageHistory,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _has_env_vars() -> bool:
|
||||||
|
return all(
|
||||||
|
[
|
||||||
|
"ASTRA_DB_APPLICATION_TOKEN" in os.environ,
|
||||||
|
"ASTRA_DB_API_ENDPOINT" in os.environ,
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(scope="function")
|
||||||
|
def history1() -> Iterable[AstraDBChatMessageHistory]:
|
||||||
|
history1 = AstraDBChatMessageHistory(
|
||||||
|
session_id="session-test-1",
|
||||||
|
collection_name="langchain_cmh_test",
|
||||||
|
token=os.environ["ASTRA_DB_APPLICATION_TOKEN"],
|
||||||
|
api_endpoint=os.environ["ASTRA_DB_API_ENDPOINT"],
|
||||||
|
namespace=os.environ.get("ASTRA_DB_KEYSPACE"),
|
||||||
|
)
|
||||||
|
yield history1
|
||||||
|
history1.astra_db.delete_collection("langchain_cmh_test")
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(scope="function")
|
||||||
|
def history2() -> Iterable[AstraDBChatMessageHistory]:
|
||||||
|
history2 = AstraDBChatMessageHistory(
|
||||||
|
session_id="session-test-2",
|
||||||
|
collection_name="langchain_cmh_test",
|
||||||
|
token=os.environ["ASTRA_DB_APPLICATION_TOKEN"],
|
||||||
|
api_endpoint=os.environ["ASTRA_DB_API_ENDPOINT"],
|
||||||
|
namespace=os.environ.get("ASTRA_DB_KEYSPACE"),
|
||||||
|
)
|
||||||
|
yield history2
|
||||||
|
history2.astra_db.delete_collection("langchain_cmh_test")
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.requires("astrapy")
|
||||||
|
@pytest.mark.skipif(not _has_env_vars(), reason="Missing Astra DB env. vars")
|
||||||
|
def test_memory_with_message_store(history1: AstraDBChatMessageHistory) -> None:
|
||||||
|
"""Test the memory with a message store."""
|
||||||
|
memory = ConversationBufferMemory(
|
||||||
|
memory_key="baz",
|
||||||
|
chat_memory=history1,
|
||||||
|
return_messages=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
assert memory.chat_memory.messages == []
|
||||||
|
|
||||||
|
# add some messages
|
||||||
|
memory.chat_memory.add_ai_message("This is me, the AI")
|
||||||
|
memory.chat_memory.add_user_message("This is me, the human")
|
||||||
|
|
||||||
|
messages = memory.chat_memory.messages
|
||||||
|
expected = [
|
||||||
|
AIMessage(content="This is me, the AI"),
|
||||||
|
HumanMessage(content="This is me, the human"),
|
||||||
|
]
|
||||||
|
assert messages == expected
|
||||||
|
|
||||||
|
# clear the store
|
||||||
|
memory.chat_memory.clear()
|
||||||
|
|
||||||
|
assert memory.chat_memory.messages == []
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.requires("astrapy")
|
||||||
|
@pytest.mark.skipif(not _has_env_vars(), reason="Missing Astra DB env. vars")
|
||||||
|
def test_memory_separate_session_ids(
|
||||||
|
history1: AstraDBChatMessageHistory, history2: AstraDBChatMessageHistory
|
||||||
|
) -> None:
|
||||||
|
"""Test that separate session IDs do not share entries."""
|
||||||
|
memory1 = ConversationBufferMemory(
|
||||||
|
memory_key="mk1",
|
||||||
|
chat_memory=history1,
|
||||||
|
return_messages=True,
|
||||||
|
)
|
||||||
|
memory2 = ConversationBufferMemory(
|
||||||
|
memory_key="mk2",
|
||||||
|
chat_memory=history2,
|
||||||
|
return_messages=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
memory1.chat_memory.add_ai_message("Just saying.")
|
||||||
|
|
||||||
|
assert memory2.chat_memory.messages == []
|
||||||
|
|
||||||
|
memory2.chat_memory.clear()
|
||||||
|
|
||||||
|
assert memory1.chat_memory.messages != []
|
||||||
|
|
||||||
|
memory1.chat_memory.clear()
|
||||||
|
|
||||||
|
assert memory1.chat_memory.messages == []
|
@ -1,6 +1,7 @@
|
|||||||
from langchain.memory import __all__
|
from langchain.memory import __all__
|
||||||
|
|
||||||
EXPECTED_ALL = [
|
EXPECTED_ALL = [
|
||||||
|
"AstraDBChatMessageHistory",
|
||||||
"CassandraChatMessageHistory",
|
"CassandraChatMessageHistory",
|
||||||
"ChatMessageHistory",
|
"ChatMessageHistory",
|
||||||
"CombinedMemory",
|
"CombinedMemory",
|
||||||
|
Loading…
Reference in New Issue
Block a user