mirror of
https://github.com/hwchase17/langchain.git
synced 2025-06-25 16:13:25 +00:00
Add async methods to BaseChatMessageHistory and BaseMemory (#16728)
Adds: * async methods to BaseChatMessageHistory * async methods to ChatMessageHistory * async methods to BaseMemory * async methods to BaseChatMemory * async methods to ConversationBufferMemory * tests of ConversationBufferMemory's async methods **Twitter handle:** cbornet_
This commit is contained in:
parent
b3c3b58f2c
commit
2ef69fe11b
@ -1,4 +1,4 @@
|
|||||||
from typing import List
|
from typing import List, Sequence
|
||||||
|
|
||||||
from langchain_core.chat_history import BaseChatMessageHistory
|
from langchain_core.chat_history import BaseChatMessageHistory
|
||||||
from langchain_core.messages import BaseMessage
|
from langchain_core.messages import BaseMessage
|
||||||
@ -13,9 +13,19 @@ class ChatMessageHistory(BaseChatMessageHistory, BaseModel):
|
|||||||
|
|
||||||
messages: List[BaseMessage] = Field(default_factory=list)
|
messages: List[BaseMessage] = Field(default_factory=list)
|
||||||
|
|
||||||
|
async def aget_messages(self) -> List[BaseMessage]:
|
||||||
|
return self.messages
|
||||||
|
|
||||||
def add_message(self, message: BaseMessage) -> None:
|
def add_message(self, message: BaseMessage) -> None:
|
||||||
"""Add a self-created message to the store"""
|
"""Add a self-created message to the store"""
|
||||||
self.messages.append(message)
|
self.messages.append(message)
|
||||||
|
|
||||||
|
async def aadd_messages(self, messages: Sequence[BaseMessage]) -> None:
|
||||||
|
"""Add messages to the store"""
|
||||||
|
self.add_messages(messages)
|
||||||
|
|
||||||
def clear(self) -> None:
|
def clear(self) -> None:
|
||||||
self.messages = []
|
self.messages = []
|
||||||
|
|
||||||
|
async def aclear(self) -> None:
|
||||||
|
self.clear()
|
||||||
|
@ -9,16 +9,32 @@ from langchain_core.messages import (
|
|||||||
HumanMessage,
|
HumanMessage,
|
||||||
get_buffer_string,
|
get_buffer_string,
|
||||||
)
|
)
|
||||||
|
from langchain_core.runnables import run_in_executor
|
||||||
|
|
||||||
|
|
||||||
class BaseChatMessageHistory(ABC):
|
class BaseChatMessageHistory(ABC):
|
||||||
"""Abstract base class for storing chat message history.
|
"""Abstract base class for storing chat message history.
|
||||||
|
|
||||||
Implementations should over-ride the add_messages method to handle bulk addition
|
Implementations guidelines:
|
||||||
of messages.
|
|
||||||
|
|
||||||
The default implementation of add_message will correctly call add_messages, so
|
Implementations are expected to over-ride all or some of the following methods:
|
||||||
it is not necessary to implement both methods.
|
|
||||||
|
* add_messages: sync variant for bulk addition of messages
|
||||||
|
* aadd_messages: async variant for bulk addition of messages
|
||||||
|
* messages: sync variant for getting messages
|
||||||
|
* aget_messages: async variant for getting messages
|
||||||
|
* clear: sync variant for clearing messages
|
||||||
|
* aclear: async variant for clearing messages
|
||||||
|
|
||||||
|
add_messages contains a default implementation that calls add_message
|
||||||
|
for each message in the sequence. This is provided for backwards compatibility
|
||||||
|
with existing implementations which only had add_message.
|
||||||
|
|
||||||
|
Async variants all have default implementations that call the sync variants.
|
||||||
|
Implementers can choose to over-ride the async implementations to provide
|
||||||
|
truly async implementations.
|
||||||
|
|
||||||
|
Usage guidelines:
|
||||||
|
|
||||||
When used for updating history, users should favor usage of `add_messages`
|
When used for updating history, users should favor usage of `add_messages`
|
||||||
over `add_message` or other variants like `add_user_message` and `add_ai_message`
|
over `add_message` or other variants like `add_user_message` and `add_ai_message`
|
||||||
@ -54,7 +70,22 @@ class BaseChatMessageHistory(ABC):
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
messages: List[BaseMessage]
|
messages: List[BaseMessage]
|
||||||
"""A list of Messages stored in-memory."""
|
"""A property or attribute that returns a list of messages.
|
||||||
|
|
||||||
|
In general, getting the messages may involve IO to the underlying
|
||||||
|
persistence layer, so this operation is expected to incur some
|
||||||
|
latency.
|
||||||
|
"""
|
||||||
|
|
||||||
|
async def aget_messages(self) -> List[BaseMessage]:
|
||||||
|
"""Async version of getting messages.
|
||||||
|
|
||||||
|
Can over-ride this method to provide an efficient async implementation.
|
||||||
|
|
||||||
|
In general, fetching messages may involve IO to the underlying
|
||||||
|
persistence layer.
|
||||||
|
"""
|
||||||
|
return await run_in_executor(None, lambda: self.messages)
|
||||||
|
|
||||||
def add_user_message(self, message: Union[HumanMessage, str]) -> None:
|
def add_user_message(self, message: Union[HumanMessage, str]) -> None:
|
||||||
"""Convenience method for adding a human message string to the store.
|
"""Convenience method for adding a human message string to the store.
|
||||||
@ -98,7 +129,7 @@ class BaseChatMessageHistory(ABC):
|
|||||||
"""
|
"""
|
||||||
if type(self).add_messages != BaseChatMessageHistory.add_messages:
|
if type(self).add_messages != BaseChatMessageHistory.add_messages:
|
||||||
# This means that the sub-class has implemented an efficient add_messages
|
# This means that the sub-class has implemented an efficient add_messages
|
||||||
# method, so we should usage of add_message to that.
|
# method, so we should use it.
|
||||||
self.add_messages([message])
|
self.add_messages([message])
|
||||||
else:
|
else:
|
||||||
raise NotImplementedError(
|
raise NotImplementedError(
|
||||||
@ -118,10 +149,22 @@ class BaseChatMessageHistory(ABC):
|
|||||||
for message in messages:
|
for message in messages:
|
||||||
self.add_message(message)
|
self.add_message(message)
|
||||||
|
|
||||||
|
async def aadd_messages(self, messages: Sequence[BaseMessage]) -> None:
|
||||||
|
"""Add a list of messages.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
messages: A list of BaseMessage objects to store.
|
||||||
|
"""
|
||||||
|
await run_in_executor(None, self.add_messages, messages)
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def clear(self) -> None:
|
def clear(self) -> None:
|
||||||
"""Remove all messages from the store"""
|
"""Remove all messages from the store"""
|
||||||
|
|
||||||
|
async def aclear(self) -> None:
|
||||||
|
"""Remove all messages from the store"""
|
||||||
|
await run_in_executor(None, self.clear)
|
||||||
|
|
||||||
def __str__(self) -> str:
|
def __str__(self) -> str:
|
||||||
"""Return a string representation of the chat history."""
|
"""Return a string representation of the chat history."""
|
||||||
return get_buffer_string(self.messages)
|
return get_buffer_string(self.messages)
|
||||||
|
@ -4,6 +4,7 @@ from abc import ABC, abstractmethod
|
|||||||
from typing import Any, Dict, List
|
from typing import Any, Dict, List
|
||||||
|
|
||||||
from langchain_core.load.serializable import Serializable
|
from langchain_core.load.serializable import Serializable
|
||||||
|
from langchain_core.runnables import run_in_executor
|
||||||
|
|
||||||
|
|
||||||
class BaseMemory(Serializable, ABC):
|
class BaseMemory(Serializable, ABC):
|
||||||
@ -50,10 +51,24 @@ class BaseMemory(Serializable, ABC):
|
|||||||
def load_memory_variables(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
|
def load_memory_variables(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
|
||||||
"""Return key-value pairs given the text input to the chain."""
|
"""Return key-value pairs given the text input to the chain."""
|
||||||
|
|
||||||
|
async def aload_memory_variables(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
|
||||||
|
"""Return key-value pairs given the text input to the chain."""
|
||||||
|
return await run_in_executor(None, self.load_memory_variables, inputs)
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def save_context(self, inputs: Dict[str, Any], outputs: Dict[str, str]) -> None:
|
def save_context(self, inputs: Dict[str, Any], outputs: Dict[str, str]) -> None:
|
||||||
"""Save the context of this chain run to memory."""
|
"""Save the context of this chain run to memory."""
|
||||||
|
|
||||||
|
async def asave_context(
|
||||||
|
self, inputs: Dict[str, Any], outputs: Dict[str, str]
|
||||||
|
) -> None:
|
||||||
|
"""Save the context of this chain run to memory."""
|
||||||
|
await run_in_executor(None, self.save_context, inputs, outputs)
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def clear(self) -> None:
|
def clear(self) -> None:
|
||||||
"""Clear memory contents."""
|
"""Clear memory contents."""
|
||||||
|
|
||||||
|
async def aclear(self) -> None:
|
||||||
|
"""Clear memory contents."""
|
||||||
|
await run_in_executor(None, self.clear)
|
||||||
|
@ -66,3 +66,36 @@ def test_bulk_message_implementation_only() -> None:
|
|||||||
assert len(store) == 4
|
assert len(store) == 4
|
||||||
assert store[2] == HumanMessage(content="Hello")
|
assert store[2] == HumanMessage(content="Hello")
|
||||||
assert store[3] == HumanMessage(content="World")
|
assert store[3] == HumanMessage(content="World")
|
||||||
|
|
||||||
|
|
||||||
|
async def test_async_interface() -> None:
|
||||||
|
"""Test async interface for BaseChatMessageHistory."""
|
||||||
|
|
||||||
|
class BulkAddHistory(BaseChatMessageHistory):
|
||||||
|
def __init__(self) -> None:
|
||||||
|
self.messages = []
|
||||||
|
|
||||||
|
def add_messages(self, message: Sequence[BaseMessage]) -> None:
|
||||||
|
"""Add a message to the store."""
|
||||||
|
self.messages.extend(message)
|
||||||
|
|
||||||
|
def clear(self) -> None:
|
||||||
|
"""Clear the store."""
|
||||||
|
self.messages.clear()
|
||||||
|
|
||||||
|
chat_history = BulkAddHistory()
|
||||||
|
await chat_history.aadd_messages(
|
||||||
|
[HumanMessage(content="Hello"), HumanMessage(content="World")]
|
||||||
|
)
|
||||||
|
assert await chat_history.aget_messages() == [
|
||||||
|
HumanMessage(content="Hello"),
|
||||||
|
HumanMessage(content="World"),
|
||||||
|
]
|
||||||
|
await chat_history.aadd_messages([HumanMessage(content="!")])
|
||||||
|
assert await chat_history.aget_messages() == [
|
||||||
|
HumanMessage(content="Hello"),
|
||||||
|
HumanMessage(content="World"),
|
||||||
|
HumanMessage(content="!"),
|
||||||
|
]
|
||||||
|
await chat_history.aclear()
|
||||||
|
assert await chat_history.aget_messages() == []
|
||||||
|
@ -19,20 +19,40 @@ class ConversationBufferMemory(BaseChatMemory):
|
|||||||
"""String buffer of memory."""
|
"""String buffer of memory."""
|
||||||
return self.buffer_as_messages if self.return_messages else self.buffer_as_str
|
return self.buffer_as_messages if self.return_messages else self.buffer_as_str
|
||||||
|
|
||||||
@property
|
async def abuffer(self) -> Any:
|
||||||
def buffer_as_str(self) -> str:
|
"""String buffer of memory."""
|
||||||
"""Exposes the buffer as a string in case return_messages is True."""
|
return (
|
||||||
|
await self.abuffer_as_messages()
|
||||||
|
if self.return_messages
|
||||||
|
else await self.abuffer_as_str()
|
||||||
|
)
|
||||||
|
|
||||||
|
def _buffer_as_str(self, messages: List[BaseMessage]) -> str:
|
||||||
return get_buffer_string(
|
return get_buffer_string(
|
||||||
self.chat_memory.messages,
|
messages,
|
||||||
human_prefix=self.human_prefix,
|
human_prefix=self.human_prefix,
|
||||||
ai_prefix=self.ai_prefix,
|
ai_prefix=self.ai_prefix,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def buffer_as_str(self) -> str:
|
||||||
|
"""Exposes the buffer as a string in case return_messages is True."""
|
||||||
|
return self._buffer_as_str(self.chat_memory.messages)
|
||||||
|
|
||||||
|
async def abuffer_as_str(self) -> str:
|
||||||
|
"""Exposes the buffer as a string in case return_messages is True."""
|
||||||
|
messages = await self.chat_memory.aget_messages()
|
||||||
|
return self._buffer_as_str(messages)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def buffer_as_messages(self) -> List[BaseMessage]:
|
def buffer_as_messages(self) -> List[BaseMessage]:
|
||||||
"""Exposes the buffer as a list of messages in case return_messages is False."""
|
"""Exposes the buffer as a list of messages in case return_messages is False."""
|
||||||
return self.chat_memory.messages
|
return self.chat_memory.messages
|
||||||
|
|
||||||
|
async def abuffer_as_messages(self) -> List[BaseMessage]:
|
||||||
|
"""Exposes the buffer as a list of messages in case return_messages is False."""
|
||||||
|
return await self.chat_memory.aget_messages()
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def memory_variables(self) -> List[str]:
|
def memory_variables(self) -> List[str]:
|
||||||
"""Will always return list of memory variables.
|
"""Will always return list of memory variables.
|
||||||
@ -45,6 +65,11 @@ class ConversationBufferMemory(BaseChatMemory):
|
|||||||
"""Return history buffer."""
|
"""Return history buffer."""
|
||||||
return {self.memory_key: self.buffer}
|
return {self.memory_key: self.buffer}
|
||||||
|
|
||||||
|
async def aload_memory_variables(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
|
||||||
|
"""Return key-value pairs given the text input to the chain."""
|
||||||
|
buffer = await self.abuffer()
|
||||||
|
return {self.memory_key: buffer}
|
||||||
|
|
||||||
|
|
||||||
class ConversationStringBufferMemory(BaseMemory):
|
class ConversationStringBufferMemory(BaseMemory):
|
||||||
"""Buffer for storing conversation memory."""
|
"""Buffer for storing conversation memory."""
|
||||||
@ -77,6 +102,10 @@ class ConversationStringBufferMemory(BaseMemory):
|
|||||||
"""Return history buffer."""
|
"""Return history buffer."""
|
||||||
return {self.memory_key: self.buffer}
|
return {self.memory_key: self.buffer}
|
||||||
|
|
||||||
|
async def aload_memory_variables(self, inputs: Dict[str, Any]) -> Dict[str, str]:
|
||||||
|
"""Return history buffer."""
|
||||||
|
return self.load_memory_variables(inputs)
|
||||||
|
|
||||||
def save_context(self, inputs: Dict[str, Any], outputs: Dict[str, str]) -> None:
|
def save_context(self, inputs: Dict[str, Any], outputs: Dict[str, str]) -> None:
|
||||||
"""Save context from this conversation to buffer."""
|
"""Save context from this conversation to buffer."""
|
||||||
if self.input_key is None:
|
if self.input_key is None:
|
||||||
@ -93,6 +122,15 @@ class ConversationStringBufferMemory(BaseMemory):
|
|||||||
ai = f"{self.ai_prefix}: " + outputs[output_key]
|
ai = f"{self.ai_prefix}: " + outputs[output_key]
|
||||||
self.buffer += "\n" + "\n".join([human, ai])
|
self.buffer += "\n" + "\n".join([human, ai])
|
||||||
|
|
||||||
|
async def asave_context(
|
||||||
|
self, inputs: Dict[str, Any], outputs: Dict[str, str]
|
||||||
|
) -> None:
|
||||||
|
"""Save context from this conversation to buffer."""
|
||||||
|
return self.save_context(inputs, outputs)
|
||||||
|
|
||||||
def clear(self) -> None:
|
def clear(self) -> None:
|
||||||
"""Clear memory contents."""
|
"""Clear memory contents."""
|
||||||
self.buffer = ""
|
self.buffer = ""
|
||||||
|
|
||||||
|
async def aclear(self) -> None:
|
||||||
|
self.clear()
|
||||||
|
@ -4,6 +4,7 @@ from typing import Any, Dict, Optional, Tuple
|
|||||||
from langchain_community.chat_message_histories.in_memory import ChatMessageHistory
|
from langchain_community.chat_message_histories.in_memory import ChatMessageHistory
|
||||||
from langchain_core.chat_history import BaseChatMessageHistory
|
from langchain_core.chat_history import BaseChatMessageHistory
|
||||||
from langchain_core.memory import BaseMemory
|
from langchain_core.memory import BaseMemory
|
||||||
|
from langchain_core.messages import AIMessage, HumanMessage
|
||||||
from langchain_core.pydantic_v1 import Field
|
from langchain_core.pydantic_v1 import Field
|
||||||
|
|
||||||
from langchain.memory.utils import get_prompt_input_key
|
from langchain.memory.utils import get_prompt_input_key
|
||||||
@ -35,9 +36,23 @@ class BaseChatMemory(BaseMemory, ABC):
|
|||||||
def save_context(self, inputs: Dict[str, Any], outputs: Dict[str, str]) -> None:
|
def save_context(self, inputs: Dict[str, Any], outputs: Dict[str, str]) -> None:
|
||||||
"""Save context from this conversation to buffer."""
|
"""Save context from this conversation to buffer."""
|
||||||
input_str, output_str = self._get_input_output(inputs, outputs)
|
input_str, output_str = self._get_input_output(inputs, outputs)
|
||||||
self.chat_memory.add_user_message(input_str)
|
self.chat_memory.add_messages(
|
||||||
self.chat_memory.add_ai_message(output_str)
|
[HumanMessage(content=input_str), AIMessage(content=output_str)]
|
||||||
|
)
|
||||||
|
|
||||||
|
async def asave_context(
|
||||||
|
self, inputs: Dict[str, Any], outputs: Dict[str, str]
|
||||||
|
) -> None:
|
||||||
|
"""Save context from this conversation to buffer."""
|
||||||
|
input_str, output_str = self._get_input_output(inputs, outputs)
|
||||||
|
await self.chat_memory.aadd_messages(
|
||||||
|
[HumanMessage(content=input_str), AIMessage(content=output_str)]
|
||||||
|
)
|
||||||
|
|
||||||
def clear(self) -> None:
|
def clear(self) -> None:
|
||||||
"""Clear memory contents."""
|
"""Clear memory contents."""
|
||||||
self.chat_memory.clear()
|
self.chat_memory.clear()
|
||||||
|
|
||||||
|
async def aclear(self) -> None:
|
||||||
|
"""Clear memory contents."""
|
||||||
|
await self.chat_memory.aclear()
|
||||||
|
@ -14,14 +14,22 @@ def test_memory_ai_prefix() -> None:
|
|||||||
"""Test that ai_prefix in the memory component works."""
|
"""Test that ai_prefix in the memory component works."""
|
||||||
memory = ConversationBufferMemory(memory_key="foo", ai_prefix="Assistant")
|
memory = ConversationBufferMemory(memory_key="foo", ai_prefix="Assistant")
|
||||||
memory.save_context({"input": "bar"}, {"output": "foo"})
|
memory.save_context({"input": "bar"}, {"output": "foo"})
|
||||||
assert memory.buffer == "Human: bar\nAssistant: foo"
|
assert memory.load_memory_variables({}) == {"foo": "Human: bar\nAssistant: foo"}
|
||||||
|
|
||||||
|
|
||||||
def test_memory_human_prefix() -> None:
|
def test_memory_human_prefix() -> None:
|
||||||
"""Test that human_prefix in the memory component works."""
|
"""Test that human_prefix in the memory component works."""
|
||||||
memory = ConversationBufferMemory(memory_key="foo", human_prefix="Friend")
|
memory = ConversationBufferMemory(memory_key="foo", human_prefix="Friend")
|
||||||
memory.save_context({"input": "bar"}, {"output": "foo"})
|
memory.save_context({"input": "bar"}, {"output": "foo"})
|
||||||
assert memory.buffer == "Friend: bar\nAI: foo"
|
assert memory.load_memory_variables({}) == {"foo": "Friend: bar\nAI: foo"}
|
||||||
|
|
||||||
|
|
||||||
|
async def test_memory_async() -> None:
|
||||||
|
memory = ConversationBufferMemory(memory_key="foo", ai_prefix="Assistant")
|
||||||
|
await memory.asave_context({"input": "bar"}, {"output": "foo"})
|
||||||
|
assert await memory.aload_memory_variables({}) == {
|
||||||
|
"foo": "Human: bar\nAssistant: foo"
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
def test_conversation_chain_works() -> None:
|
def test_conversation_chain_works() -> None:
|
||||||
@ -100,3 +108,23 @@ def test_clearing_conversation_memory(memory: BaseMemory) -> None:
|
|||||||
|
|
||||||
memory.clear()
|
memory.clear()
|
||||||
assert memory.load_memory_variables({}) == {"baz": ""}
|
assert memory.load_memory_variables({}) == {"baz": ""}
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
"memory",
|
||||||
|
[
|
||||||
|
ConversationBufferMemory(memory_key="baz"),
|
||||||
|
ConversationSummaryMemory(llm=FakeLLM(), memory_key="baz"),
|
||||||
|
ConversationBufferWindowMemory(memory_key="baz"),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
async def test_clearing_conversation_memory_async(memory: BaseMemory) -> None:
|
||||||
|
"""Test clearing the conversation memory."""
|
||||||
|
# This is a good input because the input is not the same as baz.
|
||||||
|
good_inputs = {"foo": "bar", "baz": "foo"}
|
||||||
|
# This is a good output because there is one variable.
|
||||||
|
good_outputs = {"bar": "foo"}
|
||||||
|
await memory.asave_context(good_inputs, good_outputs)
|
||||||
|
|
||||||
|
await memory.aclear()
|
||||||
|
assert await memory.aload_memory_variables({}) == {"baz": ""}
|
||||||
|
Loading…
Reference in New Issue
Block a user