mirror of
https://github.com/hwchase17/langchain.git
synced 2025-09-22 02:50:31 +00:00
Add async methods to BaseChatMessageHistory and BaseMemory (#16728)
Adds: * async methods to BaseChatMessageHistory * async methods to ChatMessageHistory * async methods to BaseMemory * async methods to BaseChatMemory * async methods to ConversationBufferMemory * tests of ConversationBufferMemory's async methods **Twitter handle:** cbornet_
This commit is contained in:
committed by
GitHub
parent
b3c3b58f2c
commit
2ef69fe11b
@@ -19,20 +19,40 @@ class ConversationBufferMemory(BaseChatMemory):
|
||||
"""String buffer of memory."""
|
||||
return self.buffer_as_messages if self.return_messages else self.buffer_as_str
|
||||
|
||||
@property
|
||||
def buffer_as_str(self) -> str:
|
||||
"""Exposes the buffer as a string in case return_messages is True."""
|
||||
async def abuffer(self) -> Any:
|
||||
"""String buffer of memory."""
|
||||
return (
|
||||
await self.abuffer_as_messages()
|
||||
if self.return_messages
|
||||
else await self.abuffer_as_str()
|
||||
)
|
||||
|
||||
def _buffer_as_str(self, messages: List[BaseMessage]) -> str:
|
||||
return get_buffer_string(
|
||||
self.chat_memory.messages,
|
||||
messages,
|
||||
human_prefix=self.human_prefix,
|
||||
ai_prefix=self.ai_prefix,
|
||||
)
|
||||
|
||||
@property
|
||||
def buffer_as_str(self) -> str:
|
||||
"""Exposes the buffer as a string in case return_messages is True."""
|
||||
return self._buffer_as_str(self.chat_memory.messages)
|
||||
|
||||
async def abuffer_as_str(self) -> str:
|
||||
"""Exposes the buffer as a string in case return_messages is True."""
|
||||
messages = await self.chat_memory.aget_messages()
|
||||
return self._buffer_as_str(messages)
|
||||
|
||||
@property
|
||||
def buffer_as_messages(self) -> List[BaseMessage]:
|
||||
"""Exposes the buffer as a list of messages in case return_messages is False."""
|
||||
return self.chat_memory.messages
|
||||
|
||||
async def abuffer_as_messages(self) -> List[BaseMessage]:
|
||||
"""Exposes the buffer as a list of messages in case return_messages is False."""
|
||||
return await self.chat_memory.aget_messages()
|
||||
|
||||
@property
|
||||
def memory_variables(self) -> List[str]:
|
||||
"""Will always return list of memory variables.
|
||||
@@ -45,6 +65,11 @@ class ConversationBufferMemory(BaseChatMemory):
|
||||
"""Return history buffer."""
|
||||
return {self.memory_key: self.buffer}
|
||||
|
||||
async def aload_memory_variables(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""Return key-value pairs given the text input to the chain."""
|
||||
buffer = await self.abuffer()
|
||||
return {self.memory_key: buffer}
|
||||
|
||||
|
||||
class ConversationStringBufferMemory(BaseMemory):
|
||||
"""Buffer for storing conversation memory."""
|
||||
@@ -77,6 +102,10 @@ class ConversationStringBufferMemory(BaseMemory):
|
||||
"""Return history buffer."""
|
||||
return {self.memory_key: self.buffer}
|
||||
|
||||
async def aload_memory_variables(self, inputs: Dict[str, Any]) -> Dict[str, str]:
|
||||
"""Return history buffer."""
|
||||
return self.load_memory_variables(inputs)
|
||||
|
||||
def save_context(self, inputs: Dict[str, Any], outputs: Dict[str, str]) -> None:
|
||||
"""Save context from this conversation to buffer."""
|
||||
if self.input_key is None:
|
||||
@@ -93,6 +122,15 @@ class ConversationStringBufferMemory(BaseMemory):
|
||||
ai = f"{self.ai_prefix}: " + outputs[output_key]
|
||||
self.buffer += "\n" + "\n".join([human, ai])
|
||||
|
||||
async def asave_context(
|
||||
self, inputs: Dict[str, Any], outputs: Dict[str, str]
|
||||
) -> None:
|
||||
"""Save context from this conversation to buffer."""
|
||||
return self.save_context(inputs, outputs)
|
||||
|
||||
def clear(self) -> None:
|
||||
"""Clear memory contents."""
|
||||
self.buffer = ""
|
||||
|
||||
async def aclear(self) -> None:
|
||||
self.clear()
|
||||
|
@@ -4,6 +4,7 @@ from typing import Any, Dict, Optional, Tuple
|
||||
from langchain_community.chat_message_histories.in_memory import ChatMessageHistory
|
||||
from langchain_core.chat_history import BaseChatMessageHistory
|
||||
from langchain_core.memory import BaseMemory
|
||||
from langchain_core.messages import AIMessage, HumanMessage
|
||||
from langchain_core.pydantic_v1 import Field
|
||||
|
||||
from langchain.memory.utils import get_prompt_input_key
|
||||
@@ -35,9 +36,23 @@ class BaseChatMemory(BaseMemory, ABC):
|
||||
def save_context(self, inputs: Dict[str, Any], outputs: Dict[str, str]) -> None:
|
||||
"""Save context from this conversation to buffer."""
|
||||
input_str, output_str = self._get_input_output(inputs, outputs)
|
||||
self.chat_memory.add_user_message(input_str)
|
||||
self.chat_memory.add_ai_message(output_str)
|
||||
self.chat_memory.add_messages(
|
||||
[HumanMessage(content=input_str), AIMessage(content=output_str)]
|
||||
)
|
||||
|
||||
async def asave_context(
|
||||
self, inputs: Dict[str, Any], outputs: Dict[str, str]
|
||||
) -> None:
|
||||
"""Save context from this conversation to buffer."""
|
||||
input_str, output_str = self._get_input_output(inputs, outputs)
|
||||
await self.chat_memory.aadd_messages(
|
||||
[HumanMessage(content=input_str), AIMessage(content=output_str)]
|
||||
)
|
||||
|
||||
def clear(self) -> None:
|
||||
"""Clear memory contents."""
|
||||
self.chat_memory.clear()
|
||||
|
||||
async def aclear(self) -> None:
|
||||
"""Clear memory contents."""
|
||||
await self.chat_memory.aclear()
|
||||
|
@@ -14,14 +14,22 @@ def test_memory_ai_prefix() -> None:
|
||||
"""Test that ai_prefix in the memory component works."""
|
||||
memory = ConversationBufferMemory(memory_key="foo", ai_prefix="Assistant")
|
||||
memory.save_context({"input": "bar"}, {"output": "foo"})
|
||||
assert memory.buffer == "Human: bar\nAssistant: foo"
|
||||
assert memory.load_memory_variables({}) == {"foo": "Human: bar\nAssistant: foo"}
|
||||
|
||||
|
||||
def test_memory_human_prefix() -> None:
|
||||
"""Test that human_prefix in the memory component works."""
|
||||
memory = ConversationBufferMemory(memory_key="foo", human_prefix="Friend")
|
||||
memory.save_context({"input": "bar"}, {"output": "foo"})
|
||||
assert memory.buffer == "Friend: bar\nAI: foo"
|
||||
assert memory.load_memory_variables({}) == {"foo": "Friend: bar\nAI: foo"}
|
||||
|
||||
|
||||
async def test_memory_async() -> None:
|
||||
memory = ConversationBufferMemory(memory_key="foo", ai_prefix="Assistant")
|
||||
await memory.asave_context({"input": "bar"}, {"output": "foo"})
|
||||
assert await memory.aload_memory_variables({}) == {
|
||||
"foo": "Human: bar\nAssistant: foo"
|
||||
}
|
||||
|
||||
|
||||
def test_conversation_chain_works() -> None:
|
||||
@@ -100,3 +108,23 @@ def test_clearing_conversation_memory(memory: BaseMemory) -> None:
|
||||
|
||||
memory.clear()
|
||||
assert memory.load_memory_variables({}) == {"baz": ""}
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"memory",
|
||||
[
|
||||
ConversationBufferMemory(memory_key="baz"),
|
||||
ConversationSummaryMemory(llm=FakeLLM(), memory_key="baz"),
|
||||
ConversationBufferWindowMemory(memory_key="baz"),
|
||||
],
|
||||
)
|
||||
async def test_clearing_conversation_memory_async(memory: BaseMemory) -> None:
|
||||
"""Test clearing the conversation memory."""
|
||||
# This is a good input because the input is not the same as baz.
|
||||
good_inputs = {"foo": "bar", "baz": "foo"}
|
||||
# This is a good output because there is one variable.
|
||||
good_outputs = {"bar": "foo"}
|
||||
await memory.asave_context(good_inputs, good_outputs)
|
||||
|
||||
await memory.aclear()
|
||||
assert await memory.aload_memory_variables({}) == {"baz": ""}
|
||||
|
Reference in New Issue
Block a user