mirror of
https://github.com/hwchase17/langchain.git
synced 2025-07-10 06:55:09 +00:00
community[minor]: Add async methods to AstraDBChatMessageHistory (#17572)
This commit is contained in:
parent
ff1f985a2a
commit
387cacb881
@ -3,9 +3,12 @@ from __future__ import annotations
|
||||
|
||||
import json
|
||||
import time
|
||||
from typing import TYPE_CHECKING, List, Optional
|
||||
from typing import TYPE_CHECKING, List, Optional, Sequence
|
||||
|
||||
from langchain_community.utilities.astradb import _AstraDBEnvironment
|
||||
from langchain_community.utilities.astradb import (
|
||||
SetupMode,
|
||||
_AstraDBCollectionEnvironment,
|
||||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from astrapy.db import AstraDB
|
||||
@ -45,24 +48,30 @@ class AstraDBChatMessageHistory(BaseChatMessageHistory):
|
||||
api_endpoint: Optional[str] = None,
|
||||
astra_db_client: Optional[AstraDB] = None,
|
||||
namespace: Optional[str] = None,
|
||||
setup_mode: SetupMode = SetupMode.SYNC,
|
||||
pre_delete_collection: bool = False,
|
||||
) -> None:
|
||||
"""Create an Astra DB chat message history."""
|
||||
astra_env = _AstraDBEnvironment(
|
||||
self.astra_env = _AstraDBCollectionEnvironment(
|
||||
collection_name=collection_name,
|
||||
token=token,
|
||||
api_endpoint=api_endpoint,
|
||||
astra_db_client=astra_db_client,
|
||||
namespace=namespace,
|
||||
setup_mode=setup_mode,
|
||||
pre_delete_collection=pre_delete_collection,
|
||||
)
|
||||
self.astra_db = astra_env.astra_db
|
||||
|
||||
self.collection = self.astra_db.create_collection(collection_name)
|
||||
self.collection = self.astra_env.collection
|
||||
self.async_collection = self.astra_env.async_collection
|
||||
|
||||
self.session_id = session_id
|
||||
self.collection_name = collection_name
|
||||
|
||||
@property
|
||||
def messages(self) -> List[BaseMessage]: # type: ignore
|
||||
def messages(self) -> List[BaseMessage]:
|
||||
"""Retrieve all session messages from DB"""
|
||||
self.astra_env.ensure_db_setup()
|
||||
message_blobs = [
|
||||
doc["body_blob"]
|
||||
for doc in sorted(
|
||||
@ -82,16 +91,63 @@ class AstraDBChatMessageHistory(BaseChatMessageHistory):
|
||||
messages = messages_from_dict(items)
|
||||
return messages
|
||||
|
||||
def add_message(self, message: BaseMessage) -> None:
|
||||
@messages.setter
|
||||
def messages(self, messages: List[BaseMessage]) -> None:
|
||||
raise NotImplementedError("Use add_messages instead")
|
||||
|
||||
async def aget_messages(self) -> List[BaseMessage]:
|
||||
"""Retrieve all session messages from DB"""
|
||||
await self.astra_env.aensure_db_setup()
|
||||
docs = self.async_collection.paginated_find(
|
||||
filter={
|
||||
"session_id": self.session_id,
|
||||
},
|
||||
projection={
|
||||
"timestamp": 1,
|
||||
"body_blob": 1,
|
||||
},
|
||||
)
|
||||
sorted_docs = sorted(
|
||||
[doc async for doc in docs],
|
||||
key=lambda _doc: _doc["timestamp"],
|
||||
)
|
||||
message_blobs = [doc["body_blob"] for doc in sorted_docs]
|
||||
items = [json.loads(message_blob) for message_blob in message_blobs]
|
||||
messages = messages_from_dict(items)
|
||||
return messages
|
||||
|
||||
def add_messages(self, messages: Sequence[BaseMessage]) -> None:
|
||||
"""Write a message to the table"""
|
||||
self.collection.insert_one(
|
||||
self.astra_env.ensure_db_setup()
|
||||
docs = [
|
||||
{
|
||||
"timestamp": time.time(),
|
||||
"session_id": self.session_id,
|
||||
"body_blob": json.dumps(message_to_dict(message)),
|
||||
}
|
||||
)
|
||||
for message in messages
|
||||
]
|
||||
self.collection.chunked_insert_many(docs)
|
||||
|
||||
async def aadd_messages(self, messages: Sequence[BaseMessage]) -> None:
|
||||
"""Write a message to the table"""
|
||||
await self.astra_env.aensure_db_setup()
|
||||
docs = [
|
||||
{
|
||||
"timestamp": time.time(),
|
||||
"session_id": self.session_id,
|
||||
"body_blob": json.dumps(message_to_dict(message)),
|
||||
}
|
||||
for message in messages
|
||||
]
|
||||
await self.async_collection.chunked_insert_many(docs)
|
||||
|
||||
def clear(self) -> None:
|
||||
"""Clear session memory from DB"""
|
||||
self.astra_env.ensure_db_setup()
|
||||
self.collection.delete_many(filter={"session_id": self.session_id})
|
||||
|
||||
async def aclear(self) -> None:
|
||||
"""Clear session memory from DB"""
|
||||
await self.astra_env.aensure_db_setup()
|
||||
await self.async_collection.delete_many(filter={"session_id": self.session_id})
|
||||
|
@ -1,10 +1,11 @@
|
||||
import os
|
||||
from typing import Iterable
|
||||
from typing import AsyncIterable, Iterable
|
||||
|
||||
import pytest
|
||||
from langchain_community.chat_message_histories.astradb import (
|
||||
AstraDBChatMessageHistory,
|
||||
)
|
||||
from langchain_community.utilities.astradb import SetupMode
|
||||
from langchain_core.messages import AIMessage, HumanMessage
|
||||
|
||||
from langchain.memory import ConversationBufferMemory
|
||||
@ -29,7 +30,7 @@ def history1() -> Iterable[AstraDBChatMessageHistory]:
|
||||
namespace=os.environ.get("ASTRA_DB_KEYSPACE"),
|
||||
)
|
||||
yield history1
|
||||
history1.astra_db.delete_collection("langchain_cmh_test")
|
||||
history1.collection.astra_db.delete_collection("langchain_cmh_test")
|
||||
|
||||
|
||||
@pytest.fixture(scope="function")
|
||||
@ -42,7 +43,35 @@ def history2() -> Iterable[AstraDBChatMessageHistory]:
|
||||
namespace=os.environ.get("ASTRA_DB_KEYSPACE"),
|
||||
)
|
||||
yield history2
|
||||
history2.astra_db.delete_collection("langchain_cmh_test")
|
||||
history2.collection.astra_db.delete_collection("langchain_cmh_test")
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
async def async_history1() -> AsyncIterable[AstraDBChatMessageHistory]:
|
||||
history1 = AstraDBChatMessageHistory(
|
||||
session_id="async-session-test-1",
|
||||
collection_name="langchain_cmh_test",
|
||||
token=os.environ["ASTRA_DB_APPLICATION_TOKEN"],
|
||||
api_endpoint=os.environ["ASTRA_DB_API_ENDPOINT"],
|
||||
namespace=os.environ.get("ASTRA_DB_KEYSPACE"),
|
||||
setup_mode=SetupMode.ASYNC,
|
||||
)
|
||||
yield history1
|
||||
await history1.async_collection.astra_db.delete_collection("langchain_cmh_test")
|
||||
|
||||
|
||||
@pytest.fixture(scope="function")
|
||||
async def async_history2() -> AsyncIterable[AstraDBChatMessageHistory]:
|
||||
history2 = AstraDBChatMessageHistory(
|
||||
session_id="async-session-test-2",
|
||||
collection_name="langchain_cmh_test",
|
||||
token=os.environ["ASTRA_DB_APPLICATION_TOKEN"],
|
||||
api_endpoint=os.environ["ASTRA_DB_API_ENDPOINT"],
|
||||
namespace=os.environ.get("ASTRA_DB_KEYSPACE"),
|
||||
setup_mode=SetupMode.ASYNC,
|
||||
)
|
||||
yield history2
|
||||
await history2.async_collection.astra_db.delete_collection("langchain_cmh_test")
|
||||
|
||||
|
||||
@pytest.mark.requires("astrapy")
|
||||
@ -58,8 +87,12 @@ def test_memory_with_message_store(history1: AstraDBChatMessageHistory) -> None:
|
||||
assert memory.chat_memory.messages == []
|
||||
|
||||
# add some messages
|
||||
memory.chat_memory.add_ai_message("This is me, the AI")
|
||||
memory.chat_memory.add_user_message("This is me, the human")
|
||||
memory.chat_memory.add_messages(
|
||||
[
|
||||
AIMessage(content="This is me, the AI"),
|
||||
HumanMessage(content="This is me, the human"),
|
||||
]
|
||||
)
|
||||
|
||||
messages = memory.chat_memory.messages
|
||||
expected = [
|
||||
@ -74,6 +107,41 @@ def test_memory_with_message_store(history1: AstraDBChatMessageHistory) -> None:
|
||||
assert memory.chat_memory.messages == []
|
||||
|
||||
|
||||
@pytest.mark.requires("astrapy")
|
||||
@pytest.mark.skipif(not _has_env_vars(), reason="Missing Astra DB env. vars")
|
||||
async def test_memory_with_message_store_async(
|
||||
async_history1: AstraDBChatMessageHistory,
|
||||
) -> None:
|
||||
"""Test the memory with a message store."""
|
||||
memory = ConversationBufferMemory(
|
||||
memory_key="baz",
|
||||
chat_memory=async_history1,
|
||||
return_messages=True,
|
||||
)
|
||||
|
||||
assert await memory.chat_memory.aget_messages() == []
|
||||
|
||||
# add some messages
|
||||
await memory.chat_memory.aadd_messages(
|
||||
[
|
||||
AIMessage(content="This is me, the AI"),
|
||||
HumanMessage(content="This is me, the human"),
|
||||
]
|
||||
)
|
||||
|
||||
messages = await memory.chat_memory.aget_messages()
|
||||
expected = [
|
||||
AIMessage(content="This is me, the AI"),
|
||||
HumanMessage(content="This is me, the human"),
|
||||
]
|
||||
assert messages == expected
|
||||
|
||||
# clear the store
|
||||
await memory.chat_memory.aclear()
|
||||
|
||||
assert await memory.chat_memory.aget_messages() == []
|
||||
|
||||
|
||||
@pytest.mark.requires("astrapy")
|
||||
@pytest.mark.skipif(not _has_env_vars(), reason="Missing Astra DB env. vars")
|
||||
def test_memory_separate_session_ids(
|
||||
@ -91,7 +159,7 @@ def test_memory_separate_session_ids(
|
||||
return_messages=True,
|
||||
)
|
||||
|
||||
memory1.chat_memory.add_ai_message("Just saying.")
|
||||
memory1.chat_memory.add_messages([AIMessage(content="Just saying.")])
|
||||
|
||||
assert memory2.chat_memory.messages == []
|
||||
|
||||
@ -102,3 +170,33 @@ def test_memory_separate_session_ids(
|
||||
memory1.chat_memory.clear()
|
||||
|
||||
assert memory1.chat_memory.messages == []
|
||||
|
||||
|
||||
@pytest.mark.requires("astrapy")
|
||||
@pytest.mark.skipif(not _has_env_vars(), reason="Missing Astra DB env. vars")
|
||||
async def test_memory_separate_session_ids_async(
|
||||
async_history1: AstraDBChatMessageHistory, async_history2: AstraDBChatMessageHistory
|
||||
) -> None:
|
||||
"""Test that separate session IDs do not share entries."""
|
||||
memory1 = ConversationBufferMemory(
|
||||
memory_key="mk1",
|
||||
chat_memory=async_history1,
|
||||
return_messages=True,
|
||||
)
|
||||
memory2 = ConversationBufferMemory(
|
||||
memory_key="mk2",
|
||||
chat_memory=async_history2,
|
||||
return_messages=True,
|
||||
)
|
||||
|
||||
await memory1.chat_memory.aadd_messages([AIMessage(content="Just saying.")])
|
||||
|
||||
assert await memory2.chat_memory.aget_messages() == []
|
||||
|
||||
await memory2.chat_memory.aclear()
|
||||
|
||||
assert await memory1.chat_memory.aget_messages() != []
|
||||
|
||||
await memory1.chat_memory.aclear()
|
||||
|
||||
assert await memory1.chat_memory.aget_messages() == []
|
||||
|
Loading…
Reference in New Issue
Block a user