From 2ef69fe11bb33dfe1419741d477e2dc0b82a5d37 Mon Sep 17 00:00:00 2001 From: Christophe Bornet Date: Mon, 5 Feb 2024 10:20:28 -0800 Subject: [PATCH] Add async methods to BaseChatMessageHistory and BaseMemory (#16728) Adds: * async methods to BaseChatMessageHistory * async methods to ChatMessageHistory * async methods to BaseMemory * async methods to BaseChatMemory * async methods to ConversationBufferMemory * tests of ConversationBufferMemory's async methods **Twitter handle:** cbornet_ --- .../chat_message_histories/in_memory.py | 12 +++- libs/core/langchain_core/chat_history.py | 55 +++++++++++++++++-- libs/core/langchain_core/memory.py | 15 +++++ .../chat_history/test_chat_history.py | 33 +++++++++++ libs/langchain/langchain/memory/buffer.py | 46 ++++++++++++++-- .../langchain/langchain/memory/chat_memory.py | 19 ++++++- .../unit_tests/chains/test_conversation.py | 32 ++++++++++- 7 files changed, 197 insertions(+), 15 deletions(-) diff --git a/libs/community/langchain_community/chat_message_histories/in_memory.py b/libs/community/langchain_community/chat_message_histories/in_memory.py index 8c76e850dd9..fe6c6406524 100644 --- a/libs/community/langchain_community/chat_message_histories/in_memory.py +++ b/libs/community/langchain_community/chat_message_histories/in_memory.py @@ -1,4 +1,4 @@ -from typing import List +from typing import List, Sequence from langchain_core.chat_history import BaseChatMessageHistory from langchain_core.messages import BaseMessage @@ -13,9 +13,19 @@ class ChatMessageHistory(BaseChatMessageHistory, BaseModel): messages: List[BaseMessage] = Field(default_factory=list) + async def aget_messages(self) -> List[BaseMessage]: + return self.messages + def add_message(self, message: BaseMessage) -> None: """Add a self-created message to the store""" self.messages.append(message) + async def aadd_messages(self, messages: Sequence[BaseMessage]) -> None: + """Add messages to the store""" + self.add_messages(messages) + def clear(self) -> None: self.messages = [] + + async def aclear(self) -> None: + self.clear() diff --git a/libs/core/langchain_core/chat_history.py b/libs/core/langchain_core/chat_history.py index e558c1479ed..1042e2e8ef8 100644 --- a/libs/core/langchain_core/chat_history.py +++ b/libs/core/langchain_core/chat_history.py @@ -9,16 +9,32 @@ from langchain_core.messages import ( HumanMessage, get_buffer_string, ) +from langchain_core.runnables import run_in_executor class BaseChatMessageHistory(ABC): """Abstract base class for storing chat message history. - Implementations should over-ride the add_messages method to handle bulk addition - of messages. + Implementations guidelines: - The default implementation of add_message will correctly call add_messages, so - it is not necessary to implement both methods. + Implementations are expected to over-ride all or some of the following methods: + + * add_messages: sync variant for bulk addition of messages + * aadd_messages: async variant for bulk addition of messages + * messages: sync variant for getting messages + * aget_messages: async variant for getting messages + * clear: sync variant for clearing messages + * aclear: async variant for clearing messages + + add_messages contains a default implementation that calls add_message + for each message in the sequence. This is provided for backwards compatibility + with existing implementations which only had add_message. + + Async variants all have default implementations that call the sync variants. + Implementers can choose to over-ride the async implementations to provide + truly async implementations. + + Usage guidelines: When used for updating history, users should favor usage of `add_messages` over `add_message` or other variants like `add_user_message` and `add_ai_message` @@ -54,7 +70,22 @@ class BaseChatMessageHistory(ABC): """ messages: List[BaseMessage] - """A list of Messages stored in-memory.""" + """A property or attribute that returns a list of messages. + + In general, getting the messages may involve IO to the underlying + persistence layer, so this operation is expected to incur some + latency. + """ + + async def aget_messages(self) -> List[BaseMessage]: + """Async version of getting messages. + + Can over-ride this method to provide an efficient async implementation. + + In general, fetching messages may involve IO to the underlying + persistence layer. + """ + return await run_in_executor(None, lambda: self.messages) def add_user_message(self, message: Union[HumanMessage, str]) -> None: """Convenience method for adding a human message string to the store. @@ -98,7 +129,7 @@ class BaseChatMessageHistory(ABC): """ if type(self).add_messages != BaseChatMessageHistory.add_messages: # This means that the sub-class has implemented an efficient add_messages - # method, so we should usage of add_message to that. + # method, so we should use it. self.add_messages([message]) else: raise NotImplementedError( @@ -118,10 +149,22 @@ class BaseChatMessageHistory(ABC): for message in messages: self.add_message(message) + async def aadd_messages(self, messages: Sequence[BaseMessage]) -> None: + """Add a list of messages. + + Args: + messages: A list of BaseMessage objects to store. + """ + await run_in_executor(None, self.add_messages, messages) + @abstractmethod def clear(self) -> None: """Remove all messages from the store""" + async def aclear(self) -> None: + """Remove all messages from the store""" + await run_in_executor(None, self.clear) + def __str__(self) -> str: """Return a string representation of the chat history.""" return get_buffer_string(self.messages) diff --git a/libs/core/langchain_core/memory.py b/libs/core/langchain_core/memory.py index 0b362661cfd..ad61e90fdad 100644 --- a/libs/core/langchain_core/memory.py +++ b/libs/core/langchain_core/memory.py @@ -4,6 +4,7 @@ from abc import ABC, abstractmethod from typing import Any, Dict, List from langchain_core.load.serializable import Serializable +from langchain_core.runnables import run_in_executor class BaseMemory(Serializable, ABC): @@ -50,10 +51,24 @@ class BaseMemory(Serializable, ABC): def load_memory_variables(self, inputs: Dict[str, Any]) -> Dict[str, Any]: """Return key-value pairs given the text input to the chain.""" + async def aload_memory_variables(self, inputs: Dict[str, Any]) -> Dict[str, Any]: + """Return key-value pairs given the text input to the chain.""" + return await run_in_executor(None, self.load_memory_variables, inputs) + @abstractmethod def save_context(self, inputs: Dict[str, Any], outputs: Dict[str, str]) -> None: """Save the context of this chain run to memory.""" + async def asave_context( + self, inputs: Dict[str, Any], outputs: Dict[str, str] + ) -> None: + """Save the context of this chain run to memory.""" + await run_in_executor(None, self.save_context, inputs, outputs) + @abstractmethod def clear(self) -> None: """Clear memory contents.""" + + async def aclear(self) -> None: + """Clear memory contents.""" + await run_in_executor(None, self.clear) diff --git a/libs/core/tests/unit_tests/chat_history/test_chat_history.py b/libs/core/tests/unit_tests/chat_history/test_chat_history.py index 0f6b696e8fb..e7d2b724f7e 100644 --- a/libs/core/tests/unit_tests/chat_history/test_chat_history.py +++ b/libs/core/tests/unit_tests/chat_history/test_chat_history.py @@ -66,3 +66,36 @@ def test_bulk_message_implementation_only() -> None: assert len(store) == 4 assert store[2] == HumanMessage(content="Hello") assert store[3] == HumanMessage(content="World") + + +async def test_async_interface() -> None: + """Test async interface for BaseChatMessageHistory.""" + + class BulkAddHistory(BaseChatMessageHistory): + def __init__(self) -> None: + self.messages = [] + + def add_messages(self, message: Sequence[BaseMessage]) -> None: + """Add a message to the store.""" + self.messages.extend(message) + + def clear(self) -> None: + """Clear the store.""" + self.messages.clear() + + chat_history = BulkAddHistory() + await chat_history.aadd_messages( + [HumanMessage(content="Hello"), HumanMessage(content="World")] + ) + assert await chat_history.aget_messages() == [ + HumanMessage(content="Hello"), + HumanMessage(content="World"), + ] + await chat_history.aadd_messages([HumanMessage(content="!")]) + assert await chat_history.aget_messages() == [ + HumanMessage(content="Hello"), + HumanMessage(content="World"), + HumanMessage(content="!"), + ] + await chat_history.aclear() + assert await chat_history.aget_messages() == [] diff --git a/libs/langchain/langchain/memory/buffer.py b/libs/langchain/langchain/memory/buffer.py index c8b819fb754..3cc9b162dab 100644 --- a/libs/langchain/langchain/memory/buffer.py +++ b/libs/langchain/langchain/memory/buffer.py @@ -19,20 +19,40 @@ class ConversationBufferMemory(BaseChatMemory): """String buffer of memory.""" return self.buffer_as_messages if self.return_messages else self.buffer_as_str - @property - def buffer_as_str(self) -> str: - """Exposes the buffer as a string in case return_messages is True.""" + async def abuffer(self) -> Any: + """String buffer of memory.""" + return ( + await self.abuffer_as_messages() + if self.return_messages + else await self.abuffer_as_str() + ) + + def _buffer_as_str(self, messages: List[BaseMessage]) -> str: return get_buffer_string( - self.chat_memory.messages, + messages, human_prefix=self.human_prefix, ai_prefix=self.ai_prefix, ) + @property + def buffer_as_str(self) -> str: + """Exposes the buffer as a string in case return_messages is True.""" + return self._buffer_as_str(self.chat_memory.messages) + + async def abuffer_as_str(self) -> str: + """Exposes the buffer as a string in case return_messages is True.""" + messages = await self.chat_memory.aget_messages() + return self._buffer_as_str(messages) + @property def buffer_as_messages(self) -> List[BaseMessage]: """Exposes the buffer as a list of messages in case return_messages is False.""" return self.chat_memory.messages + async def abuffer_as_messages(self) -> List[BaseMessage]: + """Exposes the buffer as a list of messages in case return_messages is False.""" + return await self.chat_memory.aget_messages() + @property def memory_variables(self) -> List[str]: """Will always return list of memory variables. @@ -45,6 +65,11 @@ class ConversationBufferMemory(BaseChatMemory): """Return history buffer.""" return {self.memory_key: self.buffer} + async def aload_memory_variables(self, inputs: Dict[str, Any]) -> Dict[str, Any]: + """Return key-value pairs given the text input to the chain.""" + buffer = await self.abuffer() + return {self.memory_key: buffer} + class ConversationStringBufferMemory(BaseMemory): """Buffer for storing conversation memory.""" @@ -77,6 +102,10 @@ class ConversationStringBufferMemory(BaseMemory): """Return history buffer.""" return {self.memory_key: self.buffer} + async def aload_memory_variables(self, inputs: Dict[str, Any]) -> Dict[str, str]: + """Return history buffer.""" + return self.load_memory_variables(inputs) + def save_context(self, inputs: Dict[str, Any], outputs: Dict[str, str]) -> None: """Save context from this conversation to buffer.""" if self.input_key is None: @@ -93,6 +122,15 @@ class ConversationStringBufferMemory(BaseMemory): ai = f"{self.ai_prefix}: " + outputs[output_key] self.buffer += "\n" + "\n".join([human, ai]) + async def asave_context( + self, inputs: Dict[str, Any], outputs: Dict[str, str] + ) -> None: + """Save context from this conversation to buffer.""" + return self.save_context(inputs, outputs) + def clear(self) -> None: """Clear memory contents.""" self.buffer = "" + + async def aclear(self) -> None: + self.clear() diff --git a/libs/langchain/langchain/memory/chat_memory.py b/libs/langchain/langchain/memory/chat_memory.py index 7808264209b..ad030c3f71a 100644 --- a/libs/langchain/langchain/memory/chat_memory.py +++ b/libs/langchain/langchain/memory/chat_memory.py @@ -4,6 +4,7 @@ from typing import Any, Dict, Optional, Tuple from langchain_community.chat_message_histories.in_memory import ChatMessageHistory from langchain_core.chat_history import BaseChatMessageHistory from langchain_core.memory import BaseMemory +from langchain_core.messages import AIMessage, HumanMessage from langchain_core.pydantic_v1 import Field from langchain.memory.utils import get_prompt_input_key @@ -35,9 +36,23 @@ class BaseChatMemory(BaseMemory, ABC): def save_context(self, inputs: Dict[str, Any], outputs: Dict[str, str]) -> None: """Save context from this conversation to buffer.""" input_str, output_str = self._get_input_output(inputs, outputs) - self.chat_memory.add_user_message(input_str) - self.chat_memory.add_ai_message(output_str) + self.chat_memory.add_messages( + [HumanMessage(content=input_str), AIMessage(content=output_str)] + ) + + async def asave_context( + self, inputs: Dict[str, Any], outputs: Dict[str, str] + ) -> None: + """Save context from this conversation to buffer.""" + input_str, output_str = self._get_input_output(inputs, outputs) + await self.chat_memory.aadd_messages( + [HumanMessage(content=input_str), AIMessage(content=output_str)] + ) def clear(self) -> None: """Clear memory contents.""" self.chat_memory.clear() + + async def aclear(self) -> None: + """Clear memory contents.""" + await self.chat_memory.aclear() diff --git a/libs/langchain/tests/unit_tests/chains/test_conversation.py b/libs/langchain/tests/unit_tests/chains/test_conversation.py index 86ecd647e8d..d00b1e4bc6e 100644 --- a/libs/langchain/tests/unit_tests/chains/test_conversation.py +++ b/libs/langchain/tests/unit_tests/chains/test_conversation.py @@ -14,14 +14,22 @@ def test_memory_ai_prefix() -> None: """Test that ai_prefix in the memory component works.""" memory = ConversationBufferMemory(memory_key="foo", ai_prefix="Assistant") memory.save_context({"input": "bar"}, {"output": "foo"}) - assert memory.buffer == "Human: bar\nAssistant: foo" + assert memory.load_memory_variables({}) == {"foo": "Human: bar\nAssistant: foo"} def test_memory_human_prefix() -> None: """Test that human_prefix in the memory component works.""" memory = ConversationBufferMemory(memory_key="foo", human_prefix="Friend") memory.save_context({"input": "bar"}, {"output": "foo"}) - assert memory.buffer == "Friend: bar\nAI: foo" + assert memory.load_memory_variables({}) == {"foo": "Friend: bar\nAI: foo"} + + +async def test_memory_async() -> None: + memory = ConversationBufferMemory(memory_key="foo", ai_prefix="Assistant") + await memory.asave_context({"input": "bar"}, {"output": "foo"}) + assert await memory.aload_memory_variables({}) == { + "foo": "Human: bar\nAssistant: foo" + } def test_conversation_chain_works() -> None: @@ -100,3 +108,23 @@ def test_clearing_conversation_memory(memory: BaseMemory) -> None: memory.clear() assert memory.load_memory_variables({}) == {"baz": ""} + + +@pytest.mark.parametrize( + "memory", + [ + ConversationBufferMemory(memory_key="baz"), + ConversationSummaryMemory(llm=FakeLLM(), memory_key="baz"), + ConversationBufferWindowMemory(memory_key="baz"), + ], +) +async def test_clearing_conversation_memory_async(memory: BaseMemory) -> None: + """Test clearing the conversation memory.""" + # This is a good input because the input is not the same as baz. + good_inputs = {"foo": "bar", "baz": "foo"} + # This is a good output because there is one variable. + good_outputs = {"bar": "foo"} + await memory.asave_context(good_inputs, good_outputs) + + await memory.aclear() + assert await memory.aload_memory_variables({}) == {"baz": ""}