mirror of
https://github.com/hwchase17/langchain.git
synced 2025-09-06 21:43:44 +00:00
Add async methods to BaseChatMessageHistory and BaseMemory (#16728)
Adds: * async methods to BaseChatMessageHistory * async methods to ChatMessageHistory * async methods to BaseMemory * async methods to BaseChatMemory * async methods to ConversationBufferMemory * tests of ConversationBufferMemory's async methods **Twitter handle:** cbornet_
This commit is contained in:
committed by
GitHub
parent
b3c3b58f2c
commit
2ef69fe11b
@@ -9,16 +9,32 @@ from langchain_core.messages import (
|
||||
HumanMessage,
|
||||
get_buffer_string,
|
||||
)
|
||||
from langchain_core.runnables import run_in_executor
|
||||
|
||||
|
||||
class BaseChatMessageHistory(ABC):
|
||||
"""Abstract base class for storing chat message history.
|
||||
|
||||
Implementations should over-ride the add_messages method to handle bulk addition
|
||||
of messages.
|
||||
Implementations guidelines:
|
||||
|
||||
The default implementation of add_message will correctly call add_messages, so
|
||||
it is not necessary to implement both methods.
|
||||
Implementations are expected to over-ride all or some of the following methods:
|
||||
|
||||
* add_messages: sync variant for bulk addition of messages
|
||||
* aadd_messages: async variant for bulk addition of messages
|
||||
* messages: sync variant for getting messages
|
||||
* aget_messages: async variant for getting messages
|
||||
* clear: sync variant for clearing messages
|
||||
* aclear: async variant for clearing messages
|
||||
|
||||
add_messages contains a default implementation that calls add_message
|
||||
for each message in the sequence. This is provided for backwards compatibility
|
||||
with existing implementations which only had add_message.
|
||||
|
||||
Async variants all have default implementations that call the sync variants.
|
||||
Implementers can choose to over-ride the async implementations to provide
|
||||
truly async implementations.
|
||||
|
||||
Usage guidelines:
|
||||
|
||||
When used for updating history, users should favor usage of `add_messages`
|
||||
over `add_message` or other variants like `add_user_message` and `add_ai_message`
|
||||
@@ -54,7 +70,22 @@ class BaseChatMessageHistory(ABC):
|
||||
"""
|
||||
|
||||
messages: List[BaseMessage]
|
||||
"""A list of Messages stored in-memory."""
|
||||
"""A property or attribute that returns a list of messages.
|
||||
|
||||
In general, getting the messages may involve IO to the underlying
|
||||
persistence layer, so this operation is expected to incur some
|
||||
latency.
|
||||
"""
|
||||
|
||||
async def aget_messages(self) -> List[BaseMessage]:
|
||||
"""Async version of getting messages.
|
||||
|
||||
Can over-ride this method to provide an efficient async implementation.
|
||||
|
||||
In general, fetching messages may involve IO to the underlying
|
||||
persistence layer.
|
||||
"""
|
||||
return await run_in_executor(None, lambda: self.messages)
|
||||
|
||||
def add_user_message(self, message: Union[HumanMessage, str]) -> None:
|
||||
"""Convenience method for adding a human message string to the store.
|
||||
@@ -98,7 +129,7 @@ class BaseChatMessageHistory(ABC):
|
||||
"""
|
||||
if type(self).add_messages != BaseChatMessageHistory.add_messages:
|
||||
# This means that the sub-class has implemented an efficient add_messages
|
||||
# method, so we should usage of add_message to that.
|
||||
# method, so we should use it.
|
||||
self.add_messages([message])
|
||||
else:
|
||||
raise NotImplementedError(
|
||||
@@ -118,10 +149,22 @@ class BaseChatMessageHistory(ABC):
|
||||
for message in messages:
|
||||
self.add_message(message)
|
||||
|
||||
async def aadd_messages(self, messages: Sequence[BaseMessage]) -> None:
|
||||
"""Add a list of messages.
|
||||
|
||||
Args:
|
||||
messages: A list of BaseMessage objects to store.
|
||||
"""
|
||||
await run_in_executor(None, self.add_messages, messages)
|
||||
|
||||
@abstractmethod
|
||||
def clear(self) -> None:
|
||||
"""Remove all messages from the store"""
|
||||
|
||||
async def aclear(self) -> None:
|
||||
"""Remove all messages from the store"""
|
||||
await run_in_executor(None, self.clear)
|
||||
|
||||
def __str__(self) -> str:
|
||||
"""Return a string representation of the chat history."""
|
||||
return get_buffer_string(self.messages)
|
||||
|
@@ -4,6 +4,7 @@ from abc import ABC, abstractmethod
|
||||
from typing import Any, Dict, List
|
||||
|
||||
from langchain_core.load.serializable import Serializable
|
||||
from langchain_core.runnables import run_in_executor
|
||||
|
||||
|
||||
class BaseMemory(Serializable, ABC):
|
||||
@@ -50,10 +51,24 @@ class BaseMemory(Serializable, ABC):
|
||||
def load_memory_variables(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""Return key-value pairs given the text input to the chain."""
|
||||
|
||||
async def aload_memory_variables(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""Return key-value pairs given the text input to the chain."""
|
||||
return await run_in_executor(None, self.load_memory_variables, inputs)
|
||||
|
||||
@abstractmethod
|
||||
def save_context(self, inputs: Dict[str, Any], outputs: Dict[str, str]) -> None:
|
||||
"""Save the context of this chain run to memory."""
|
||||
|
||||
async def asave_context(
|
||||
self, inputs: Dict[str, Any], outputs: Dict[str, str]
|
||||
) -> None:
|
||||
"""Save the context of this chain run to memory."""
|
||||
await run_in_executor(None, self.save_context, inputs, outputs)
|
||||
|
||||
@abstractmethod
|
||||
def clear(self) -> None:
|
||||
"""Clear memory contents."""
|
||||
|
||||
async def aclear(self) -> None:
|
||||
"""Clear memory contents."""
|
||||
await run_in_executor(None, self.clear)
|
||||
|
@@ -66,3 +66,36 @@ def test_bulk_message_implementation_only() -> None:
|
||||
assert len(store) == 4
|
||||
assert store[2] == HumanMessage(content="Hello")
|
||||
assert store[3] == HumanMessage(content="World")
|
||||
|
||||
|
||||
async def test_async_interface() -> None:
|
||||
"""Test async interface for BaseChatMessageHistory."""
|
||||
|
||||
class BulkAddHistory(BaseChatMessageHistory):
|
||||
def __init__(self) -> None:
|
||||
self.messages = []
|
||||
|
||||
def add_messages(self, message: Sequence[BaseMessage]) -> None:
|
||||
"""Add a message to the store."""
|
||||
self.messages.extend(message)
|
||||
|
||||
def clear(self) -> None:
|
||||
"""Clear the store."""
|
||||
self.messages.clear()
|
||||
|
||||
chat_history = BulkAddHistory()
|
||||
await chat_history.aadd_messages(
|
||||
[HumanMessage(content="Hello"), HumanMessage(content="World")]
|
||||
)
|
||||
assert await chat_history.aget_messages() == [
|
||||
HumanMessage(content="Hello"),
|
||||
HumanMessage(content="World"),
|
||||
]
|
||||
await chat_history.aadd_messages([HumanMessage(content="!")])
|
||||
assert await chat_history.aget_messages() == [
|
||||
HumanMessage(content="Hello"),
|
||||
HumanMessage(content="World"),
|
||||
HumanMessage(content="!"),
|
||||
]
|
||||
await chat_history.aclear()
|
||||
assert await chat_history.aget_messages() == []
|
||||
|
Reference in New Issue
Block a user