Astra DB, chat message history (#13836)

This PR adds a chat message history component that uses Astra DB for
persistence through the JSON API.
The `astrapy` package is required for this class to work.

I have added tests and a small notebook, and updated the relevant
references in the other docs pages.

(@rlancemartin this is the counterpart of the Cassandra equivalent class
you so helpfully reviewed back at the end of June)

Thank you!
This commit is contained in:
Stefano Lottini
2023-11-25 03:12:29 +01:00
committed by GitHub
parent 58f7e109ac
commit 272df9dcae
7 changed files with 386 additions and 0 deletions

View File

@@ -32,6 +32,7 @@ from langchain.memory.buffer import (
)
from langchain.memory.buffer_window import ConversationBufferWindowMemory
from langchain.memory.chat_message_histories import (
AstraDBChatMessageHistory,
CassandraChatMessageHistory,
ChatMessageHistory,
CosmosDBChatMessageHistory,
@@ -68,6 +69,7 @@ from langchain.memory.vectorstore import VectorStoreRetrieverMemory
from langchain.memory.zep_memory import ZepMemory
__all__ = [
"AstraDBChatMessageHistory",
"CassandraChatMessageHistory",
"ChatMessageHistory",
"CombinedMemory",

View File

@@ -1,3 +1,6 @@
from langchain.memory.chat_message_histories.astradb import (
AstraDBChatMessageHistory,
)
from langchain.memory.chat_message_histories.cassandra import (
CassandraChatMessageHistory,
)
@@ -31,6 +34,7 @@ from langchain.memory.chat_message_histories.xata import XataChatMessageHistory
from langchain.memory.chat_message_histories.zep import ZepChatMessageHistory
__all__ = [
"AstraDBChatMessageHistory",
"ChatMessageHistory",
"CassandraChatMessageHistory",
"CosmosDBChatMessageHistory",

View File

@@ -0,0 +1,114 @@
"""Astra DB - based chat message history, based on astrapy."""
from __future__ import annotations
import json
import time
import typing
from typing import List, Optional
if typing.TYPE_CHECKING:
from astrapy.db import AstraDB as LibAstraDB
from langchain_core.chat_history import BaseChatMessageHistory
from langchain_core.messages import (
BaseMessage,
message_to_dict,
messages_from_dict,
)
# Collection name used by AstraDBChatMessageHistory when the caller does not
# pass an explicit ``collection_name``.
DEFAULT_COLLECTION_NAME = "langchain_message_store"
class AstraDBChatMessageHistory(BaseChatMessageHistory):
    """Chat message history that stores history in Astra DB.

    Every message is written as one document holding the session id, the
    insertion time (``time.time()``) and the JSON-serialized message. Reads
    fetch all documents of the session and order them by that timestamp.

    Args (only keyword-arguments accepted):
        session_id: arbitrary key that is used to store the messages
            of a single chat session.
        collection_name (str): name of the Astra DB collection to create/use.
        token (Optional[str]): API token for Astra DB usage.
        api_endpoint (Optional[str]): full URL to the API endpoint,
            such as "https://<DB-ID>-us-east1.apps.astra.datastax.com".
        astra_db_client (Optional[Any]): *alternative to token+api_endpoint*,
            you can pass an already-created 'astrapy.db.AstraDB' instance.
        namespace (Optional[str]): namespace (aka keyspace) where the
            collection is created. Defaults to the database's "default namespace".

    Raises:
        ImportError: if the ``astrapy`` package cannot be imported.
        ValueError: if 'astra_db_client' is combined with 'token' or
            'api_endpoint', or if neither a client nor both 'token' and
            'api_endpoint' are supplied.
    """

    def __init__(
        self,
        *,
        session_id: str,
        collection_name: str = DEFAULT_COLLECTION_NAME,
        token: Optional[str] = None,
        api_endpoint: Optional[str] = None,
        astra_db_client: Optional[LibAstraDB] = None,  # type 'astrapy.db.AstraDB'
        namespace: Optional[str] = None,
    ) -> None:
        """Create an Astra DB chat message history."""
        # Imported lazily so that the module can be imported without astrapy.
        # ModuleNotFoundError is a subclass of ImportError, so catching
        # ImportError alone covers both.
        try:
            from astrapy.db import AstraDB as LibAstraDB
        except ImportError as import_error:
            # Chain the original error so the real cause stays visible.
            raise ImportError(
                "Could not import a recent astrapy python package. "
                "Please install it with `pip install --upgrade astrapy`."
            ) from import_error

        # Conflicting-arg checks: exactly one way of reaching the database
        # must be specified (a ready-made client, or token + api_endpoint).
        if astra_db_client is not None:
            if token is not None or api_endpoint is not None:
                raise ValueError(
                    "You cannot pass 'astra_db_client' to AstraDB if passing "
                    "'token' and 'api_endpoint'."
                )
        elif token is None or api_endpoint is None:
            # Fail fast with a clear message instead of handing None
            # credentials to the astrapy client.
            raise ValueError(
                "Either 'astra_db_client' or both 'token' and 'api_endpoint' "
                "must be provided."
            )

        self.session_id = session_id
        self.collection_name = collection_name
        self.token = token
        self.api_endpoint = api_endpoint
        self.namespace = namespace

        if astra_db_client is not None:
            self.astra_db = astra_db_client
        else:
            self.astra_db = LibAstraDB(
                token=self.token,
                api_endpoint=self.api_endpoint,
                namespace=self.namespace,
            )
        # Handle to the collection that all reads/writes below go through.
        self.collection = self.astra_db.create_collection(self.collection_name)

    @property
    def messages(self) -> List[BaseMessage]:  # type: ignore
        """Retrieve all session messages from DB, oldest first."""
        session_docs = self.collection.paginated_find(
            filter={
                "session_id": self.session_id,
            },
            projection={
                "timestamp": 1,
                "body_blob": 1,
            },
        )
        # Ordering is done here, on the stored insertion timestamp.
        ordered_docs = sorted(session_docs, key=lambda _doc: _doc["timestamp"])
        items = [json.loads(_doc["body_blob"]) for _doc in ordered_docs]
        return messages_from_dict(items)

    def add_message(self, message: BaseMessage) -> None:
        """Write a message to the collection, tagged with this session's id."""
        self.collection.insert_one(
            {
                "timestamp": time.time(),
                "session_id": self.session_id,
                "body_blob": json.dumps(message_to_dict(message)),
            }
        )

    def clear(self) -> None:
        """Clear session memory from DB (only documents of this session)."""
        self.collection.delete_many(filter={"session_id": self.session_id})

View File

@@ -0,0 +1,104 @@
import os
from typing import Iterable
import pytest
from langchain_core.messages import AIMessage, HumanMessage
from langchain.memory import ConversationBufferMemory
from langchain.memory.chat_message_histories.astradb import (
AstraDBChatMessageHistory,
)
def _has_env_vars() -> bool:
    """Tell whether both Astra DB connection variables are set in the environment."""
    required_names = (
        "ASTRA_DB_APPLICATION_TOKEN",
        "ASTRA_DB_API_ENDPOINT",
    )
    return all(name in os.environ for name in required_names)
@pytest.fixture(scope="function")
def history1() -> Iterable[AstraDBChatMessageHistory]:
    """Yield a history for session "session-test-1", then drop the test collection."""
    # Connection settings come from the environment; the keyspace is optional.
    token = os.environ["ASTRA_DB_APPLICATION_TOKEN"]
    api_endpoint = os.environ["ASTRA_DB_API_ENDPOINT"]
    namespace = os.environ.get("ASTRA_DB_KEYSPACE")
    store = AstraDBChatMessageHistory(
        session_id="session-test-1",
        collection_name="langchain_cmh_test",
        token=token,
        api_endpoint=api_endpoint,
        namespace=namespace,
    )
    yield store
    # Teardown: remove the collection the test wrote to.
    store.astra_db.delete_collection("langchain_cmh_test")
@pytest.fixture(scope="function")
def history2() -> Iterable[AstraDBChatMessageHistory]:
    """Yield a history for session "session-test-2", then drop the test collection."""
    # Connection settings come from the environment; the keyspace is optional.
    token = os.environ["ASTRA_DB_APPLICATION_TOKEN"]
    api_endpoint = os.environ["ASTRA_DB_API_ENDPOINT"]
    namespace = os.environ.get("ASTRA_DB_KEYSPACE")
    store = AstraDBChatMessageHistory(
        session_id="session-test-2",
        collection_name="langchain_cmh_test",
        token=token,
        api_endpoint=api_endpoint,
        namespace=namespace,
    )
    yield store
    # Teardown: remove the collection the test wrote to.
    store.astra_db.delete_collection("langchain_cmh_test")
@pytest.mark.requires("astrapy")
@pytest.mark.skipif(not _has_env_vars(), reason="Missing Astra DB env. vars")
def test_memory_with_message_store(history1: AstraDBChatMessageHistory) -> None:
    """Check add / read-back / clear of messages through a buffer memory."""
    buffer_memory = ConversationBufferMemory(
        memory_key="baz",
        chat_memory=history1,
        return_messages=True,
    )
    chat_store = buffer_memory.chat_memory

    # The store starts out empty.
    assert chat_store.messages == []

    # Write one AI and one human message, in that order.
    chat_store.add_ai_message("This is me, the AI")
    chat_store.add_user_message("This is me, the human")

    # They must come back in insertion order with matching types and content.
    stored = chat_store.messages
    assert stored == [
        AIMessage(content="This is me, the AI"),
        HumanMessage(content="This is me, the human"),
    ]

    # Clearing leaves the store empty again.
    chat_store.clear()
    assert chat_store.messages == []
@pytest.mark.requires("astrapy")
@pytest.mark.skipif(not _has_env_vars(), reason="Missing Astra DB env. vars")
def test_memory_separate_session_ids(
    history1: AstraDBChatMessageHistory, history2: AstraDBChatMessageHistory
) -> None:
    """Check that two session ids in the same collection stay isolated."""
    first_memory = ConversationBufferMemory(
        memory_key="mk1",
        chat_memory=history1,
        return_messages=True,
    )
    second_memory = ConversationBufferMemory(
        memory_key="mk2",
        chat_memory=history2,
        return_messages=True,
    )
    first_store = first_memory.chat_memory
    second_store = second_memory.chat_memory

    # A write to the first session is not visible from the second one.
    first_store.add_ai_message("Just saying.")
    assert second_store.messages == []

    # Clearing the second session leaves the first session's message in place.
    second_store.clear()
    assert first_store.messages != []

    # Clearing the first session finally empties it.
    first_store.clear()
    assert first_store.messages == []

View File

@@ -1,6 +1,7 @@
from langchain.memory import __all__
EXPECTED_ALL = [
"AstraDBChatMessageHistory",
"CassandraChatMessageHistory",
"ChatMessageHistory",
"CombinedMemory",