From 387cacb8816fd4682a301824488b2cb51241fe40 Mon Sep 17 00:00:00 2001 From: Christophe Bornet Date: Thu, 15 Feb 2024 15:48:42 +0100 Subject: [PATCH] community[minor]: Add async methods to AstraDBChatMessageHistory (#17572) --- .../chat_message_histories/astradb.py | 74 ++++++++++-- .../integration_tests/memory/test_astradb.py | 110 +++++++++++++++++- 2 files changed, 169 insertions(+), 15 deletions(-) diff --git a/libs/community/langchain_community/chat_message_histories/astradb.py b/libs/community/langchain_community/chat_message_histories/astradb.py index 5c90d37fc00..f820480ff26 100644 --- a/libs/community/langchain_community/chat_message_histories/astradb.py +++ b/libs/community/langchain_community/chat_message_histories/astradb.py @@ -3,9 +3,12 @@ from __future__ import annotations import json import time -from typing import TYPE_CHECKING, List, Optional +from typing import TYPE_CHECKING, List, Optional, Sequence -from langchain_community.utilities.astradb import _AstraDBEnvironment +from langchain_community.utilities.astradb import ( + SetupMode, + _AstraDBCollectionEnvironment, +) if TYPE_CHECKING: from astrapy.db import AstraDB @@ -45,24 +48,30 @@ class AstraDBChatMessageHistory(BaseChatMessageHistory): api_endpoint: Optional[str] = None, astra_db_client: Optional[AstraDB] = None, namespace: Optional[str] = None, + setup_mode: SetupMode = SetupMode.SYNC, + pre_delete_collection: bool = False, ) -> None: """Create an Astra DB chat message history.""" - astra_env = _AstraDBEnvironment( + self.astra_env = _AstraDBCollectionEnvironment( + collection_name=collection_name, token=token, api_endpoint=api_endpoint, astra_db_client=astra_db_client, namespace=namespace, + setup_mode=setup_mode, + pre_delete_collection=pre_delete_collection, ) - self.astra_db = astra_env.astra_db - self.collection = self.astra_db.create_collection(collection_name) + self.collection = self.astra_env.collection + self.async_collection = self.astra_env.async_collection self.session_id = session_id self.collection_name = collection_name @property - def messages(self) -> List[BaseMessage]: # type: ignore + def messages(self) -> List[BaseMessage]: """Retrieve all session messages from DB""" + self.astra_env.ensure_db_setup() message_blobs = [ doc["body_blob"] for doc in sorted( @@ -82,16 +91,63 @@ class AstraDBChatMessageHistory(BaseChatMessageHistory): messages = messages_from_dict(items) return messages - def add_message(self, message: BaseMessage) -> None: + @messages.setter + def messages(self, messages: List[BaseMessage]) -> None: + raise NotImplementedError("Use add_messages instead") + + async def aget_messages(self) -> List[BaseMessage]: + """Retrieve all session messages from DB""" + await self.astra_env.aensure_db_setup() + docs = self.async_collection.paginated_find( + filter={ + "session_id": self.session_id, + }, + projection={ + "timestamp": 1, + "body_blob": 1, + }, + ) + sorted_docs = sorted( + [doc async for doc in docs], + key=lambda _doc: _doc["timestamp"], + ) + message_blobs = [doc["body_blob"] for doc in sorted_docs] + items = [json.loads(message_blob) for message_blob in message_blobs] + messages = messages_from_dict(items) + return messages + + def add_messages(self, messages: Sequence[BaseMessage]) -> None: """Write a message to the table""" - self.collection.insert_one( + self.astra_env.ensure_db_setup() + docs = [ { "timestamp": time.time(), "session_id": self.session_id, "body_blob": json.dumps(message_to_dict(message)), } - ) + for message in messages + ] + self.collection.chunked_insert_many(docs) + + async def aadd_messages(self, messages: Sequence[BaseMessage]) -> None: + """Write a message to the table""" + await self.astra_env.aensure_db_setup() + docs = [ + { + "timestamp": time.time(), + "session_id": self.session_id, + "body_blob": json.dumps(message_to_dict(message)), + } + for message in messages + ] + await self.async_collection.chunked_insert_many(docs) def clear(self) -> None: """Clear session memory from DB""" + self.astra_env.ensure_db_setup() self.collection.delete_many(filter={"session_id": self.session_id}) + + async def aclear(self) -> None: + """Clear session memory from DB""" + await self.astra_env.aensure_db_setup() + await self.async_collection.delete_many(filter={"session_id": self.session_id}) diff --git a/libs/langchain/tests/integration_tests/memory/test_astradb.py b/libs/langchain/tests/integration_tests/memory/test_astradb.py index a8ed9e7574b..4caf39985ec 100644 --- a/libs/langchain/tests/integration_tests/memory/test_astradb.py +++ b/libs/langchain/tests/integration_tests/memory/test_astradb.py @@ -1,10 +1,11 @@ import os -from typing import Iterable +from typing import AsyncIterable, Iterable import pytest from langchain_community.chat_message_histories.astradb import ( AstraDBChatMessageHistory, ) +from langchain_community.utilities.astradb import SetupMode from langchain_core.messages import AIMessage, HumanMessage from langchain.memory import ConversationBufferMemory @@ -29,7 +30,7 @@ def history1() -> Iterable[AstraDBChatMessageHistory]: namespace=os.environ.get("ASTRA_DB_KEYSPACE"), ) yield history1 - history1.astra_db.delete_collection("langchain_cmh_test") + history1.collection.astra_db.delete_collection("langchain_cmh_test") @pytest.fixture(scope="function") @@ -42,7 +43,35 @@ def history2() -> Iterable[AstraDBChatMessageHistory]: namespace=os.environ.get("ASTRA_DB_KEYSPACE"), ) yield history2 - history2.astra_db.delete_collection("langchain_cmh_test") + history2.collection.astra_db.delete_collection("langchain_cmh_test") + + +@pytest.fixture +async def async_history1() -> AsyncIterable[AstraDBChatMessageHistory]: + history1 = AstraDBChatMessageHistory( + session_id="async-session-test-1", + collection_name="langchain_cmh_test", + token=os.environ["ASTRA_DB_APPLICATION_TOKEN"], + api_endpoint=os.environ["ASTRA_DB_API_ENDPOINT"], + namespace=os.environ.get("ASTRA_DB_KEYSPACE"), + setup_mode=SetupMode.ASYNC, + ) + yield history1 + await history1.async_collection.astra_db.delete_collection("langchain_cmh_test") + + +@pytest.fixture(scope="function") +async def async_history2() -> AsyncIterable[AstraDBChatMessageHistory]: + history2 = AstraDBChatMessageHistory( + session_id="async-session-test-2", + collection_name="langchain_cmh_test", + token=os.environ["ASTRA_DB_APPLICATION_TOKEN"], + api_endpoint=os.environ["ASTRA_DB_API_ENDPOINT"], + namespace=os.environ.get("ASTRA_DB_KEYSPACE"), + setup_mode=SetupMode.ASYNC, + ) + yield history2 + await history2.async_collection.astra_db.delete_collection("langchain_cmh_test") @pytest.mark.requires("astrapy") @@ -58,8 +87,12 @@ def test_memory_with_message_store(history1: AstraDBChatMessageHistory) -> None: assert memory.chat_memory.messages == [] # add some messages - memory.chat_memory.add_ai_message("This is me, the AI") - memory.chat_memory.add_user_message("This is me, the human") + memory.chat_memory.add_messages( + [ + AIMessage(content="This is me, the AI"), + HumanMessage(content="This is me, the human"), + ] + ) messages = memory.chat_memory.messages expected = [ @@ -74,6 +107,41 @@ def test_memory_with_message_store(history1: AstraDBChatMessageHistory) -> None: assert memory.chat_memory.messages == [] +@pytest.mark.requires("astrapy") +@pytest.mark.skipif(not _has_env_vars(), reason="Missing Astra DB env. vars") +async def test_memory_with_message_store_async( + async_history1: AstraDBChatMessageHistory, +) -> None: + """Test the memory with a message store.""" + memory = ConversationBufferMemory( + memory_key="baz", + chat_memory=async_history1, + return_messages=True, + ) + + assert await memory.chat_memory.aget_messages() == [] + + # add some messages + await memory.chat_memory.aadd_messages( + [ + AIMessage(content="This is me, the AI"), + HumanMessage(content="This is me, the human"), + ] + ) + + messages = await memory.chat_memory.aget_messages() + expected = [ + AIMessage(content="This is me, the AI"), + HumanMessage(content="This is me, the human"), + ] + assert messages == expected + + # clear the store + await memory.chat_memory.aclear() + + assert await memory.chat_memory.aget_messages() == [] + + @pytest.mark.requires("astrapy") @pytest.mark.skipif(not _has_env_vars(), reason="Missing Astra DB env. vars") def test_memory_separate_session_ids( @@ -91,7 +159,7 @@ def test_memory_separate_session_ids( return_messages=True, ) - memory1.chat_memory.add_ai_message("Just saying.") + memory1.chat_memory.add_messages([AIMessage(content="Just saying.")]) assert memory2.chat_memory.messages == [] @@ -102,3 +170,33 @@ def test_memory_separate_session_ids( memory1.chat_memory.clear() assert memory1.chat_memory.messages == [] + + +@pytest.mark.requires("astrapy") +@pytest.mark.skipif(not _has_env_vars(), reason="Missing Astra DB env. vars") +async def test_memory_separate_session_ids_async( + async_history1: AstraDBChatMessageHistory, async_history2: AstraDBChatMessageHistory +) -> None: + """Test that separate session IDs do not share entries.""" + memory1 = ConversationBufferMemory( + memory_key="mk1", + chat_memory=async_history1, + return_messages=True, + ) + memory2 = ConversationBufferMemory( + memory_key="mk2", + chat_memory=async_history2, + return_messages=True, + ) + + await memory1.chat_memory.aadd_messages([AIMessage(content="Just saying.")]) + + assert await memory2.chat_memory.aget_messages() == [] + + await memory2.chat_memory.aclear() + + assert await memory1.chat_memory.aget_messages() != [] + + await memory1.chat_memory.aclear() + + assert await memory1.chat_memory.aget_messages() == []