langchain[patch]: add async methods to ConversationSummaryBufferMemory (#20956)

Added asynchronously callable counterparts (`apredict_new_summary`, `abuffer`,
`aload_memory_variables`, `asave_context`, `aprune`, `aclear`) to match the
ConversationSummaryBufferMemory API documentation.
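
A minimal usage sketch of the new async surface (illustrative only; `llm` stands
for any concrete model instance the class accepts, and the parameter values are
placeholders):

```python
import asyncio

from langchain.memory import ConversationSummaryBufferMemory


async def demo(llm) -> None:
    # `llm` is assumed to be any LLM/chat model accepted by the class.
    memory = ConversationSummaryBufferMemory(llm=llm, max_token_limit=40)

    # Async counterpart of save_context(); it awaits aprune(), which
    # summarizes pruned messages once the buffer exceeds max_token_limit.
    await memory.asave_context({"input": "hi"}, {"output": "hello"})

    # Async counterpart of load_memory_variables().
    print(await memory.aload_memory_variables({}))


# asyncio.run(demo(llm))  # supply a concrete `llm` here
```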

---------

Co-authored-by: Bagatur <22008038+baskaryan@users.noreply.github.com>
Co-authored-by: Chester Curme <chester.curme@gmail.com>
Author: Thomas Meike, 2024-07-22 15:21:43 +02:00 (committed by GitHub)
parent cecd875cdc
commit 40c02cedaf
4 changed files with 120 additions and 32 deletions


@@ -34,6 +34,18 @@ class SummarizerMixin(BaseModel):
        chain = LLMChain(llm=self.llm, prompt=self.prompt)
        return chain.predict(summary=existing_summary, new_lines=new_lines)

    async def apredict_new_summary(
        self, messages: List[BaseMessage], existing_summary: str
    ) -> str:
        new_lines = get_buffer_string(
            messages,
            human_prefix=self.human_prefix,
            ai_prefix=self.ai_prefix,
        )
        chain = LLMChain(llm=self.llm, prompt=self.prompt)
        return await chain.apredict(summary=existing_summary, new_lines=new_lines)


class ConversationSummaryMemory(BaseChatMemory, SummarizerMixin):
    """Conversation summarizer to chat memory."""


@@ -19,6 +19,11 @@ class ConversationSummaryBufferMemory(BaseChatMemory, SummarizerMixin):
        """String buffer of memory."""
        return self.load_memory_variables({})[self.memory_key]

    async def abuffer(self) -> Union[str, List[BaseMessage]]:
        """Async memory buffer."""
        memory_variables = await self.aload_memory_variables({})
        return memory_variables[self.memory_key]

    @property
    def memory_variables(self) -> List[str]:
        """Will always return list of memory variables.
@@ -43,6 +48,22 @@ class ConversationSummaryBufferMemory(BaseChatMemory, SummarizerMixin):
        )
        return {self.memory_key: final_buffer}

    async def aload_memory_variables(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
        """Asynchronously return key-value pairs given the text input to the chain."""
        buffer = await self.chat_memory.aget_messages()
        if self.moving_summary_buffer != "":
            first_messages: List[BaseMessage] = [
                self.summary_message_cls(content=self.moving_summary_buffer)
            ]
            buffer = first_messages + buffer
        if self.return_messages:
            final_buffer: Any = buffer
        else:
            final_buffer = get_buffer_string(
                buffer, human_prefix=self.human_prefix, ai_prefix=self.ai_prefix
            )
        return {self.memory_key: final_buffer}

    @root_validator()
    def validate_prompt_input_variables(cls, values: Dict) -> Dict:
        """Validate that prompt input variables are consistent."""
@@ -60,6 +81,13 @@ class ConversationSummaryBufferMemory(BaseChatMemory, SummarizerMixin):
        super().save_context(inputs, outputs)
        self.prune()

    async def asave_context(
        self, inputs: Dict[str, Any], outputs: Dict[str, str]
    ) -> None:
        """Asynchronously save context from this conversation to buffer."""
        await super().asave_context(inputs, outputs)
        await self.aprune()

    def prune(self) -> None:
        """Prune buffer if it exceeds max token limit"""
        buffer = self.chat_memory.messages
@@ -73,7 +101,25 @@ class ConversationSummaryBufferMemory(BaseChatMemory, SummarizerMixin):
            pruned_memory, self.moving_summary_buffer
        )

    async def aprune(self) -> None:
        """Asynchronously prune buffer if it exceeds max token limit"""
        buffer = self.chat_memory.messages
        curr_buffer_length = self.llm.get_num_tokens_from_messages(buffer)
        if curr_buffer_length > self.max_token_limit:
            pruned_memory = []
            while curr_buffer_length > self.max_token_limit:
                pruned_memory.append(buffer.pop(0))
                curr_buffer_length = self.llm.get_num_tokens_from_messages(buffer)
            self.moving_summary_buffer = await self.apredict_new_summary(
                pruned_memory, self.moving_summary_buffer
            )

    def clear(self) -> None:
        """Clear memory contents."""
        super().clear()
        self.moving_summary_buffer = ""

    async def aclear(self) -> None:
        """Asynchronously clear memory contents."""
        await super().aclear()
        self.moving_summary_buffer = ""
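
Taken together, a hedged sketch of the remaining new coroutines (assumes the
`memory` instance from the sketch above):

```python
# abuffer() mirrors the buffer property: it returns a string, or a list
# of messages when return_messages=True.
current = await memory.abuffer()

# aclear() mirrors clear(), additionally resetting moving_summary_buffer.
await memory.aclear()
```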


@@ -1,32 +0,0 @@
"""Test memory functionality."""
from langchain.memory.summary_buffer import ConversationSummaryBufferMemory
from tests.unit_tests.llms.fake_llm import FakeLLM


def test_summary_buffer_memory_no_buffer_yet() -> None:
    """Test ConversationSummaryBufferMemory when no inputs put in buffer yet."""
    memory = ConversationSummaryBufferMemory(llm=FakeLLM(), memory_key="baz")
    output = memory.load_memory_variables({})
    assert output == {"baz": ""}


def test_summary_buffer_memory_buffer_only() -> None:
    """Test ConversationSummaryBufferMemory when only buffer."""
    memory = ConversationSummaryBufferMemory(llm=FakeLLM(), memory_key="baz")
    memory.save_context({"input": "bar"}, {"output": "foo"})
    assert memory.buffer == ["Human: bar\nAI: foo"]
    output = memory.load_memory_variables({})
    assert output == {"baz": "Human: bar\nAI: foo"}


def test_summary_buffer_memory_summary() -> None:
    """Test ConversationSummaryBufferMemory when only buffer."""
    memory = ConversationSummaryBufferMemory(
        llm=FakeLLM(), memory_key="baz", max_token_limit=13
    )
    memory.save_context({"input": "bar"}, {"output": "foo"})
    memory.save_context({"input": "bar1"}, {"output": "foo1"})
    assert memory.buffer == ["Human: bar1\nAI: foo1"]
    output = memory.load_memory_variables({})
    assert output == {"baz": "foo\nHuman: bar1\nAI: foo1"}


@@ -0,0 +1,62 @@
"""Test memory functionality."""
from langchain.memory.summary_buffer import ConversationSummaryBufferMemory
from tests.unit_tests.llms.fake_llm import FakeLLM


def test_summary_buffer_memory_no_buffer_yet() -> None:
    """Test ConversationSummaryBufferMemory when no inputs put in buffer yet."""
    memory = ConversationSummaryBufferMemory(llm=FakeLLM(), memory_key="baz")
    output = memory.load_memory_variables({})
    assert output == {"baz": ""}


async def test_summary_buffer_memory_no_buffer_yet_async() -> None:
    """Test ConversationSummaryBufferMemory when no inputs put in buffer yet."""
    memory = ConversationSummaryBufferMemory(llm=FakeLLM(), memory_key="baz")
    output = await memory.aload_memory_variables({})
    assert output == {"baz": ""}


def test_summary_buffer_memory_buffer_only() -> None:
    """Test ConversationSummaryBufferMemory when only the buffer is populated."""
    memory = ConversationSummaryBufferMemory(llm=FakeLLM(), memory_key="baz")
    memory.save_context({"input": "bar"}, {"output": "foo"})
    assert memory.buffer == "Human: bar\nAI: foo"
    output = memory.load_memory_variables({})
    assert output == {"baz": "Human: bar\nAI: foo"}


async def test_summary_buffer_memory_buffer_only_async() -> None:
    """Test ConversationSummaryBufferMemory when only the buffer is populated."""
    memory = ConversationSummaryBufferMemory(llm=FakeLLM(), memory_key="baz")
    await memory.asave_context({"input": "bar"}, {"output": "foo"})
    assert memory.buffer == "Human: bar\nAI: foo"
    output = await memory.aload_memory_variables({})
    assert output == {"baz": "Human: bar\nAI: foo"}


def test_summary_buffer_memory_summary() -> None:
    """Test ConversationSummaryBufferMemory when buffer exceeds max token limit."""
    llm = FakeLLM(queries={0: "summary"}, sequential_responses=True)
    memory = ConversationSummaryBufferMemory(
        llm=llm, memory_key="baz", max_token_limit=5
    )
    memory.save_context({"input": "bar"}, {"output": "foo"})
    memory.save_context({"input": "bar1"}, {"output": "foo1"})
    assert memory.buffer == "System: summary\nHuman: bar1\nAI: foo1"
    output = memory.load_memory_variables({})
    assert output == {"baz": "System: summary\nHuman: bar1\nAI: foo1"}


async def test_summary_buffer_memory_summary_async() -> None:
    """Test ConversationSummaryBufferMemory when buffer exceeds max token limit."""
    llm = FakeLLM(queries={0: "summary"}, sequential_responses=True)
    memory = ConversationSummaryBufferMemory(
        llm=llm, memory_key="baz", max_token_limit=5
    )
    await memory.asave_context({"input": "bar"}, {"output": "foo"})
    await memory.asave_context({"input": "bar1"}, {"output": "foo1"})
    assert memory.buffer == "System: summary\nHuman: bar1\nAI: foo1"
    output = await memory.aload_memory_variables({})
    assert output == {"baz": "System: summary\nHuman: bar1\nAI: foo1"}
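
These coroutine tests presumably run under an async-aware pytest setup (e.g.
pytest-asyncio); outside pytest, one could drive a single test with nothing
beyond the standard library:

```python
import asyncio

# Hypothetical standalone driver for one of the async tests above; under
# pytest, the async plugin supplies the event loop instead.
asyncio.run(test_summary_buffer_memory_no_buffer_yet_async())
```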