mirror of
https://github.com/hwchase17/langchain.git
synced 2025-07-10 15:06:18 +00:00
community[minor]: Add async methods to AstraDBChatMessageHistory (#17572)
This commit is contained in:
parent
ff1f985a2a
commit
387cacb881
@ -3,9 +3,12 @@ from __future__ import annotations
|
|||||||
|
|
||||||
import json
|
import json
|
||||||
import time
|
import time
|
||||||
from typing import TYPE_CHECKING, List, Optional
|
from typing import TYPE_CHECKING, List, Optional, Sequence
|
||||||
|
|
||||||
from langchain_community.utilities.astradb import _AstraDBEnvironment
|
from langchain_community.utilities.astradb import (
|
||||||
|
SetupMode,
|
||||||
|
_AstraDBCollectionEnvironment,
|
||||||
|
)
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from astrapy.db import AstraDB
|
from astrapy.db import AstraDB
|
||||||
@ -45,24 +48,30 @@ class AstraDBChatMessageHistory(BaseChatMessageHistory):
|
|||||||
api_endpoint: Optional[str] = None,
|
api_endpoint: Optional[str] = None,
|
||||||
astra_db_client: Optional[AstraDB] = None,
|
astra_db_client: Optional[AstraDB] = None,
|
||||||
namespace: Optional[str] = None,
|
namespace: Optional[str] = None,
|
||||||
|
setup_mode: SetupMode = SetupMode.SYNC,
|
||||||
|
pre_delete_collection: bool = False,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Create an Astra DB chat message history."""
|
"""Create an Astra DB chat message history."""
|
||||||
astra_env = _AstraDBEnvironment(
|
self.astra_env = _AstraDBCollectionEnvironment(
|
||||||
|
collection_name=collection_name,
|
||||||
token=token,
|
token=token,
|
||||||
api_endpoint=api_endpoint,
|
api_endpoint=api_endpoint,
|
||||||
astra_db_client=astra_db_client,
|
astra_db_client=astra_db_client,
|
||||||
namespace=namespace,
|
namespace=namespace,
|
||||||
|
setup_mode=setup_mode,
|
||||||
|
pre_delete_collection=pre_delete_collection,
|
||||||
)
|
)
|
||||||
self.astra_db = astra_env.astra_db
|
|
||||||
|
|
||||||
self.collection = self.astra_db.create_collection(collection_name)
|
self.collection = self.astra_env.collection
|
||||||
|
self.async_collection = self.astra_env.async_collection
|
||||||
|
|
||||||
self.session_id = session_id
|
self.session_id = session_id
|
||||||
self.collection_name = collection_name
|
self.collection_name = collection_name
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def messages(self) -> List[BaseMessage]: # type: ignore
|
def messages(self) -> List[BaseMessage]:
|
||||||
"""Retrieve all session messages from DB"""
|
"""Retrieve all session messages from DB"""
|
||||||
|
self.astra_env.ensure_db_setup()
|
||||||
message_blobs = [
|
message_blobs = [
|
||||||
doc["body_blob"]
|
doc["body_blob"]
|
||||||
for doc in sorted(
|
for doc in sorted(
|
||||||
@ -82,16 +91,63 @@ class AstraDBChatMessageHistory(BaseChatMessageHistory):
|
|||||||
messages = messages_from_dict(items)
|
messages = messages_from_dict(items)
|
||||||
return messages
|
return messages
|
||||||
|
|
||||||
def add_message(self, message: BaseMessage) -> None:
|
@messages.setter
|
||||||
|
def messages(self, messages: List[BaseMessage]) -> None:
|
||||||
|
raise NotImplementedError("Use add_messages instead")
|
||||||
|
|
||||||
|
async def aget_messages(self) -> List[BaseMessage]:
|
||||||
|
"""Retrieve all session messages from DB"""
|
||||||
|
await self.astra_env.aensure_db_setup()
|
||||||
|
docs = self.async_collection.paginated_find(
|
||||||
|
filter={
|
||||||
|
"session_id": self.session_id,
|
||||||
|
},
|
||||||
|
projection={
|
||||||
|
"timestamp": 1,
|
||||||
|
"body_blob": 1,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
sorted_docs = sorted(
|
||||||
|
[doc async for doc in docs],
|
||||||
|
key=lambda _doc: _doc["timestamp"],
|
||||||
|
)
|
||||||
|
message_blobs = [doc["body_blob"] for doc in sorted_docs]
|
||||||
|
items = [json.loads(message_blob) for message_blob in message_blobs]
|
||||||
|
messages = messages_from_dict(items)
|
||||||
|
return messages
|
||||||
|
|
||||||
|
def add_messages(self, messages: Sequence[BaseMessage]) -> None:
|
||||||
"""Write a message to the table"""
|
"""Write a message to the table"""
|
||||||
self.collection.insert_one(
|
self.astra_env.ensure_db_setup()
|
||||||
|
docs = [
|
||||||
{
|
{
|
||||||
"timestamp": time.time(),
|
"timestamp": time.time(),
|
||||||
"session_id": self.session_id,
|
"session_id": self.session_id,
|
||||||
"body_blob": json.dumps(message_to_dict(message)),
|
"body_blob": json.dumps(message_to_dict(message)),
|
||||||
}
|
}
|
||||||
)
|
for message in messages
|
||||||
|
]
|
||||||
|
self.collection.chunked_insert_many(docs)
|
||||||
|
|
||||||
|
async def aadd_messages(self, messages: Sequence[BaseMessage]) -> None:
|
||||||
|
"""Write a message to the table"""
|
||||||
|
await self.astra_env.aensure_db_setup()
|
||||||
|
docs = [
|
||||||
|
{
|
||||||
|
"timestamp": time.time(),
|
||||||
|
"session_id": self.session_id,
|
||||||
|
"body_blob": json.dumps(message_to_dict(message)),
|
||||||
|
}
|
||||||
|
for message in messages
|
||||||
|
]
|
||||||
|
await self.async_collection.chunked_insert_many(docs)
|
||||||
|
|
||||||
def clear(self) -> None:
|
def clear(self) -> None:
|
||||||
"""Clear session memory from DB"""
|
"""Clear session memory from DB"""
|
||||||
|
self.astra_env.ensure_db_setup()
|
||||||
self.collection.delete_many(filter={"session_id": self.session_id})
|
self.collection.delete_many(filter={"session_id": self.session_id})
|
||||||
|
|
||||||
|
async def aclear(self) -> None:
|
||||||
|
"""Clear session memory from DB"""
|
||||||
|
await self.astra_env.aensure_db_setup()
|
||||||
|
await self.async_collection.delete_many(filter={"session_id": self.session_id})
|
||||||
|
@ -1,10 +1,11 @@
|
|||||||
import os
|
import os
|
||||||
from typing import Iterable
|
from typing import AsyncIterable, Iterable
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
from langchain_community.chat_message_histories.astradb import (
|
from langchain_community.chat_message_histories.astradb import (
|
||||||
AstraDBChatMessageHistory,
|
AstraDBChatMessageHistory,
|
||||||
)
|
)
|
||||||
|
from langchain_community.utilities.astradb import SetupMode
|
||||||
from langchain_core.messages import AIMessage, HumanMessage
|
from langchain_core.messages import AIMessage, HumanMessage
|
||||||
|
|
||||||
from langchain.memory import ConversationBufferMemory
|
from langchain.memory import ConversationBufferMemory
|
||||||
@ -29,7 +30,7 @@ def history1() -> Iterable[AstraDBChatMessageHistory]:
|
|||||||
namespace=os.environ.get("ASTRA_DB_KEYSPACE"),
|
namespace=os.environ.get("ASTRA_DB_KEYSPACE"),
|
||||||
)
|
)
|
||||||
yield history1
|
yield history1
|
||||||
history1.astra_db.delete_collection("langchain_cmh_test")
|
history1.collection.astra_db.delete_collection("langchain_cmh_test")
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(scope="function")
|
@pytest.fixture(scope="function")
|
||||||
@ -42,7 +43,35 @@ def history2() -> Iterable[AstraDBChatMessageHistory]:
|
|||||||
namespace=os.environ.get("ASTRA_DB_KEYSPACE"),
|
namespace=os.environ.get("ASTRA_DB_KEYSPACE"),
|
||||||
)
|
)
|
||||||
yield history2
|
yield history2
|
||||||
history2.astra_db.delete_collection("langchain_cmh_test")
|
history2.collection.astra_db.delete_collection("langchain_cmh_test")
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
async def async_history1() -> AsyncIterable[AstraDBChatMessageHistory]:
|
||||||
|
history1 = AstraDBChatMessageHistory(
|
||||||
|
session_id="async-session-test-1",
|
||||||
|
collection_name="langchain_cmh_test",
|
||||||
|
token=os.environ["ASTRA_DB_APPLICATION_TOKEN"],
|
||||||
|
api_endpoint=os.environ["ASTRA_DB_API_ENDPOINT"],
|
||||||
|
namespace=os.environ.get("ASTRA_DB_KEYSPACE"),
|
||||||
|
setup_mode=SetupMode.ASYNC,
|
||||||
|
)
|
||||||
|
yield history1
|
||||||
|
await history1.async_collection.astra_db.delete_collection("langchain_cmh_test")
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(scope="function")
|
||||||
|
async def async_history2() -> AsyncIterable[AstraDBChatMessageHistory]:
|
||||||
|
history2 = AstraDBChatMessageHistory(
|
||||||
|
session_id="async-session-test-2",
|
||||||
|
collection_name="langchain_cmh_test",
|
||||||
|
token=os.environ["ASTRA_DB_APPLICATION_TOKEN"],
|
||||||
|
api_endpoint=os.environ["ASTRA_DB_API_ENDPOINT"],
|
||||||
|
namespace=os.environ.get("ASTRA_DB_KEYSPACE"),
|
||||||
|
setup_mode=SetupMode.ASYNC,
|
||||||
|
)
|
||||||
|
yield history2
|
||||||
|
await history2.async_collection.astra_db.delete_collection("langchain_cmh_test")
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.requires("astrapy")
|
@pytest.mark.requires("astrapy")
|
||||||
@ -58,8 +87,12 @@ def test_memory_with_message_store(history1: AstraDBChatMessageHistory) -> None:
|
|||||||
assert memory.chat_memory.messages == []
|
assert memory.chat_memory.messages == []
|
||||||
|
|
||||||
# add some messages
|
# add some messages
|
||||||
memory.chat_memory.add_ai_message("This is me, the AI")
|
memory.chat_memory.add_messages(
|
||||||
memory.chat_memory.add_user_message("This is me, the human")
|
[
|
||||||
|
AIMessage(content="This is me, the AI"),
|
||||||
|
HumanMessage(content="This is me, the human"),
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
messages = memory.chat_memory.messages
|
messages = memory.chat_memory.messages
|
||||||
expected = [
|
expected = [
|
||||||
@ -74,6 +107,41 @@ def test_memory_with_message_store(history1: AstraDBChatMessageHistory) -> None:
|
|||||||
assert memory.chat_memory.messages == []
|
assert memory.chat_memory.messages == []
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.requires("astrapy")
|
||||||
|
@pytest.mark.skipif(not _has_env_vars(), reason="Missing Astra DB env. vars")
|
||||||
|
async def test_memory_with_message_store_async(
|
||||||
|
async_history1: AstraDBChatMessageHistory,
|
||||||
|
) -> None:
|
||||||
|
"""Test the memory with a message store."""
|
||||||
|
memory = ConversationBufferMemory(
|
||||||
|
memory_key="baz",
|
||||||
|
chat_memory=async_history1,
|
||||||
|
return_messages=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
assert await memory.chat_memory.aget_messages() == []
|
||||||
|
|
||||||
|
# add some messages
|
||||||
|
await memory.chat_memory.aadd_messages(
|
||||||
|
[
|
||||||
|
AIMessage(content="This is me, the AI"),
|
||||||
|
HumanMessage(content="This is me, the human"),
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
|
messages = await memory.chat_memory.aget_messages()
|
||||||
|
expected = [
|
||||||
|
AIMessage(content="This is me, the AI"),
|
||||||
|
HumanMessage(content="This is me, the human"),
|
||||||
|
]
|
||||||
|
assert messages == expected
|
||||||
|
|
||||||
|
# clear the store
|
||||||
|
await memory.chat_memory.aclear()
|
||||||
|
|
||||||
|
assert await memory.chat_memory.aget_messages() == []
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.requires("astrapy")
|
@pytest.mark.requires("astrapy")
|
||||||
@pytest.mark.skipif(not _has_env_vars(), reason="Missing Astra DB env. vars")
|
@pytest.mark.skipif(not _has_env_vars(), reason="Missing Astra DB env. vars")
|
||||||
def test_memory_separate_session_ids(
|
def test_memory_separate_session_ids(
|
||||||
@ -91,7 +159,7 @@ def test_memory_separate_session_ids(
|
|||||||
return_messages=True,
|
return_messages=True,
|
||||||
)
|
)
|
||||||
|
|
||||||
memory1.chat_memory.add_ai_message("Just saying.")
|
memory1.chat_memory.add_messages([AIMessage(content="Just saying.")])
|
||||||
|
|
||||||
assert memory2.chat_memory.messages == []
|
assert memory2.chat_memory.messages == []
|
||||||
|
|
||||||
@ -102,3 +170,33 @@ def test_memory_separate_session_ids(
|
|||||||
memory1.chat_memory.clear()
|
memory1.chat_memory.clear()
|
||||||
|
|
||||||
assert memory1.chat_memory.messages == []
|
assert memory1.chat_memory.messages == []
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.requires("astrapy")
|
||||||
|
@pytest.mark.skipif(not _has_env_vars(), reason="Missing Astra DB env. vars")
|
||||||
|
async def test_memory_separate_session_ids_async(
|
||||||
|
async_history1: AstraDBChatMessageHistory, async_history2: AstraDBChatMessageHistory
|
||||||
|
) -> None:
|
||||||
|
"""Test that separate session IDs do not share entries."""
|
||||||
|
memory1 = ConversationBufferMemory(
|
||||||
|
memory_key="mk1",
|
||||||
|
chat_memory=async_history1,
|
||||||
|
return_messages=True,
|
||||||
|
)
|
||||||
|
memory2 = ConversationBufferMemory(
|
||||||
|
memory_key="mk2",
|
||||||
|
chat_memory=async_history2,
|
||||||
|
return_messages=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
await memory1.chat_memory.aadd_messages([AIMessage(content="Just saying.")])
|
||||||
|
|
||||||
|
assert await memory2.chat_memory.aget_messages() == []
|
||||||
|
|
||||||
|
await memory2.chat_memory.aclear()
|
||||||
|
|
||||||
|
assert await memory1.chat_memory.aget_messages() != []
|
||||||
|
|
||||||
|
await memory1.chat_memory.aclear()
|
||||||
|
|
||||||
|
assert await memory1.chat_memory.aget_messages() == []
|
||||||
|
Loading…
Reference in New Issue
Block a user