diff --git a/libs/langchain/langchain/memory/summary.py b/libs/langchain/langchain/memory/summary.py
index b7cb996cb53..1ae2418401c 100644
--- a/libs/langchain/langchain/memory/summary.py
+++ b/libs/langchain/langchain/memory/summary.py
@@ -34,6 +34,18 @@ class SummarizerMixin(BaseModel):
         chain = LLMChain(llm=self.llm, prompt=self.prompt)
         return chain.predict(summary=existing_summary, new_lines=new_lines)
 
+    async def apredict_new_summary(
+        self, messages: List[BaseMessage], existing_summary: str
+    ) -> str:
+        new_lines = get_buffer_string(
+            messages,
+            human_prefix=self.human_prefix,
+            ai_prefix=self.ai_prefix,
+        )
+
+        chain = LLMChain(llm=self.llm, prompt=self.prompt)
+        return await chain.apredict(summary=existing_summary, new_lines=new_lines)
+
 
 class ConversationSummaryMemory(BaseChatMemory, SummarizerMixin):
     """Conversation summarizer to chat memory."""
diff --git a/libs/langchain/langchain/memory/summary_buffer.py b/libs/langchain/langchain/memory/summary_buffer.py
index 389da7e4215..f17f011ab9b 100644
--- a/libs/langchain/langchain/memory/summary_buffer.py
+++ b/libs/langchain/langchain/memory/summary_buffer.py
@@ -19,6 +19,11 @@ class ConversationSummaryBufferMemory(BaseChatMemory, SummarizerMixin):
         """String buffer of memory."""
         return self.load_memory_variables({})[self.memory_key]
 
+    async def abuffer(self) -> Union[str, List[BaseMessage]]:
+        """Async memory buffer."""
+        memory_variables = await self.aload_memory_variables({})
+        return memory_variables[self.memory_key]
+
     @property
     def memory_variables(self) -> List[str]:
         """Will always return list of memory variables.
@@ -43,6 +48,22 @@ class ConversationSummaryBufferMemory(BaseChatMemory, SummarizerMixin):
         )
         return {self.memory_key: final_buffer}
 
+    async def aload_memory_variables(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
+        """Asynchronously return key-value pairs given the text input to the chain."""
+        buffer = await self.chat_memory.aget_messages()
+        if self.moving_summary_buffer != "":
+            first_messages: List[BaseMessage] = [
+                self.summary_message_cls(content=self.moving_summary_buffer)
+            ]
+            buffer = first_messages + buffer
+        if self.return_messages:
+            final_buffer: Any = buffer
+        else:
+            final_buffer = get_buffer_string(
+                buffer, human_prefix=self.human_prefix, ai_prefix=self.ai_prefix
+            )
+        return {self.memory_key: final_buffer}
+
     @root_validator()
     def validate_prompt_input_variables(cls, values: Dict) -> Dict:
         """Validate that prompt input variables are consistent."""
@@ -60,6 +81,13 @@ class ConversationSummaryBufferMemory(BaseChatMemory, SummarizerMixin):
         super().save_context(inputs, outputs)
         self.prune()
 
+    async def asave_context(
+        self, inputs: Dict[str, Any], outputs: Dict[str, str]
+    ) -> None:
+        """Asynchronously save context from this conversation to buffer."""
+        await super().asave_context(inputs, outputs)
+        await self.aprune()
+
     def prune(self) -> None:
         """Prune buffer if it exceeds max token limit"""
         buffer = self.chat_memory.messages
@@ -73,7 +101,25 @@ class ConversationSummaryBufferMemory(BaseChatMemory, SummarizerMixin):
             pruned_memory, self.moving_summary_buffer
         )
 
+    async def aprune(self) -> None:
+        """Asynchronously prune buffer if it exceeds max token limit"""
+        buffer = self.chat_memory.messages
+        curr_buffer_length = self.llm.get_num_tokens_from_messages(buffer)
+        if curr_buffer_length > self.max_token_limit:
+            pruned_memory = []
+            while curr_buffer_length > self.max_token_limit:
+                pruned_memory.append(buffer.pop(0))
+                curr_buffer_length = self.llm.get_num_tokens_from_messages(buffer)
+            self.moving_summary_buffer = await self.apredict_new_summary(
+                pruned_memory, self.moving_summary_buffer
+            )
+
     def clear(self) -> None:
         """Clear memory contents."""
         super().clear()
         self.moving_summary_buffer = ""
+
+    async def aclear(self) -> None:
+        """Asynchronously clear memory contents."""
+        await super().aclear()
+        self.moving_summary_buffer = ""
diff --git a/libs/langchain/tests/integration_tests/chains/test_memory.py b/libs/langchain/tests/integration_tests/chains/test_memory.py
deleted file mode 100644
index 826380b856b..00000000000
--- a/libs/langchain/tests/integration_tests/chains/test_memory.py
+++ /dev/null
@@ -1,32 +0,0 @@
-"""Test memory functionality."""
-
-from langchain.memory.summary_buffer import ConversationSummaryBufferMemory
-from tests.unit_tests.llms.fake_llm import FakeLLM
-
-
-def test_summary_buffer_memory_no_buffer_yet() -> None:
-    """Test ConversationSummaryBufferMemory when no inputs put in buffer yet."""
-    memory = ConversationSummaryBufferMemory(llm=FakeLLM(), memory_key="baz")
-    output = memory.load_memory_variables({})
-    assert output == {"baz": ""}
-
-
-def test_summary_buffer_memory_buffer_only() -> None:
-    """Test ConversationSummaryBufferMemory when only buffer."""
-    memory = ConversationSummaryBufferMemory(llm=FakeLLM(), memory_key="baz")
-    memory.save_context({"input": "bar"}, {"output": "foo"})
-    assert memory.buffer == ["Human: bar\nAI: foo"]
-    output = memory.load_memory_variables({})
-    assert output == {"baz": "Human: bar\nAI: foo"}
-
-
-def test_summary_buffer_memory_summary() -> None:
-    """Test ConversationSummaryBufferMemory when only buffer."""
-    memory = ConversationSummaryBufferMemory(
-        llm=FakeLLM(), memory_key="baz", max_token_limit=13
-    )
-    memory.save_context({"input": "bar"}, {"output": "foo"})
-    memory.save_context({"input": "bar1"}, {"output": "foo1"})
-    assert memory.buffer == ["Human: bar1\nAI: foo1"]
-    output = memory.load_memory_variables({})
-    assert output == {"baz": "foo\nHuman: bar1\nAI: foo1"}
diff --git a/libs/langchain/tests/unit_tests/chains/test_summary_buffer_memory.py b/libs/langchain/tests/unit_tests/chains/test_summary_buffer_memory.py
new file mode 100644
index 00000000000..f6651c50c23
--- /dev/null
+++ b/libs/langchain/tests/unit_tests/chains/test_summary_buffer_memory.py
@@ -0,0 +1,62 @@
+"""Test memory functionality."""
+
+from langchain.memory.summary_buffer import ConversationSummaryBufferMemory
+from tests.unit_tests.llms.fake_llm import FakeLLM
+
+
+def test_summary_buffer_memory_no_buffer_yet() -> None:
+    """Test ConversationSummaryBufferMemory when no inputs put in buffer yet."""
+    memory = ConversationSummaryBufferMemory(llm=FakeLLM(), memory_key="baz")
+    output = memory.load_memory_variables({})
+    assert output == {"baz": ""}
+
+
+async def test_summary_buffer_memory_no_buffer_yet_async() -> None:
+    """Test ConversationSummaryBufferMemory when no inputs put in buffer yet."""
+    memory = ConversationSummaryBufferMemory(llm=FakeLLM(), memory_key="baz")
+    output = await memory.aload_memory_variables({})
+    assert output == {"baz": ""}
+
+
+def test_summary_buffer_memory_buffer_only() -> None:
+    """Test ConversationSummaryBufferMemory when only buffer."""
+    memory = ConversationSummaryBufferMemory(llm=FakeLLM(), memory_key="baz")
+    memory.save_context({"input": "bar"}, {"output": "foo"})
+    assert memory.buffer == "Human: bar\nAI: foo"
+    output = memory.load_memory_variables({})
+    assert output == {"baz": "Human: bar\nAI: foo"}
+
+
+async def test_summary_buffer_memory_buffer_only_async() -> None:
+    """Test ConversationSummaryBufferMemory when only buffer."""
+    memory = ConversationSummaryBufferMemory(llm=FakeLLM(), memory_key="baz")
+    await memory.asave_context({"input": "bar"}, {"output": "foo"})
+    assert memory.buffer == "Human: bar\nAI: foo"
+    output = await memory.aload_memory_variables({})
+    assert output == {"baz": "Human: bar\nAI: foo"}
+
+
+def test_summary_buffer_memory_summary() -> None:
+    """Test ConversationSummaryBufferMemory when only buffer."""
+    llm = FakeLLM(queries={0: "summary"}, sequential_responses=True)
+    memory = ConversationSummaryBufferMemory(
+        llm=llm, memory_key="baz", max_token_limit=5
+    )
+    memory.save_context({"input": "bar"}, {"output": "foo"})
+    memory.save_context({"input": "bar1"}, {"output": "foo1"})
+    assert memory.buffer == "System: summary\nHuman: bar1\nAI: foo1"
+    output = memory.load_memory_variables({})
+    assert output == {"baz": "System: summary\nHuman: bar1\nAI: foo1"}
+
+
+async def test_summary_buffer_memory_summary_async() -> None:
+    """Test ConversationSummaryBufferMemory when only buffer."""
+    llm = FakeLLM(queries={0: "summary"}, sequential_responses=True)
+    memory = ConversationSummaryBufferMemory(
+        llm=llm, memory_key="baz", max_token_limit=5
+    )
+    await memory.asave_context({"input": "bar"}, {"output": "foo"})
+    await memory.asave_context({"input": "bar1"}, {"output": "foo1"})
+    assert memory.buffer == "System: summary\nHuman: bar1\nAI: foo1"
+    output = await memory.aload_memory_variables({})
+    assert output == {"baz": "System: summary\nHuman: bar1\nAI: foo1"}
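Usage sketch (not part of the diff): exercising the new async surface end to end. `ChatOpenAI` is illustrative only, any `BaseLanguageModel` works; the example assumes an OpenAI key in the environment, and `max_token_limit=40` is an arbitrary choice to force early summarization.

import asyncio

from langchain.chat_models import ChatOpenAI  # illustrative; any BaseLanguageModel works
from langchain.memory.summary_buffer import ConversationSummaryBufferMemory


async def main() -> None:
    # Low token limit (arbitrary) so pruning and summarization kick in quickly.
    memory = ConversationSummaryBufferMemory(
        llm=ChatOpenAI(), memory_key="history", max_token_limit=40
    )
    # asave_context stores the turn, then aprune() summarizes any evicted
    # messages into moving_summary_buffer via apredict_new_summary.
    await memory.asave_context({"input": "hi"}, {"output": "hello!"})
    print(await memory.aload_memory_variables({}))


asyncio.run(main())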
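The new `async def test_*` functions carry no `pytest.mark.asyncio` marker, so they presumably rely on pytest-asyncio's auto mode in the repo's test configuration. The same paths can also be driven outside pytest; a sketch assuming the in-repo `FakeLLM` helper, additionally covering `abuffer` and `aclear`, which the new tests do not call directly:

import asyncio

from langchain.memory.summary_buffer import ConversationSummaryBufferMemory
from tests.unit_tests.llms.fake_llm import FakeLLM  # in-repo test helper


async def check() -> None:
    memory = ConversationSummaryBufferMemory(llm=FakeLLM(), memory_key="baz")
    await memory.asave_context({"input": "bar"}, {"output": "foo"})
    # abuffer() is the async twin of the existing `buffer` property.
    assert await memory.abuffer() == "Human: bar\nAI: foo"
    # aclear() resets both the chat history and the moving summary.
    await memory.aclear()
    assert await memory.abuffer() == ""


asyncio.run(check())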