Mirror of https://github.com/hwchase17/langchain.git
Synced 2025-08-05 03:02:35 +00:00
langchain[patch]: add async methods to ConversationSummaryBufferMemory (#20956)
Added asynchronously callable methods, following the ConversationSummaryBufferMemory API documentation.

Co-authored-by: Bagatur <22008038+baskaryan@users.noreply.github.com>
Co-authored-by: Chester Curme <chester.curme@gmail.com>
This commit is contained in:
parent cecd875cdc
commit 40c02cedaf
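In practice, the new coroutines let ConversationSummaryBufferMemory be read and written from async code end to end. A minimal sketch of the added surface (a hypothetical demo; the llm object and the conversation turns are illustrative, not part of the commit):

import asyncio

from langchain.memory.summary_buffer import ConversationSummaryBufferMemory


async def demo(llm) -> None:
    # llm is any LangChain LLM instance; constructing one is out of scope here.
    memory = ConversationSummaryBufferMemory(
        llm=llm, memory_key="history", max_token_limit=40
    )

    # Persist one turn without blocking the event loop (async twin of save_context).
    await memory.asave_context({"input": "hi"}, {"output": "hello there"})

    # Async twins of buffer and load_memory_variables.
    print(await memory.abuffer())
    print(await memory.aload_memory_variables({}))


# asyncio.run(demo(llm))  # supply a real LLM instance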
@@ -34,6 +34,18 @@ class SummarizerMixin(BaseModel):
         chain = LLMChain(llm=self.llm, prompt=self.prompt)
         return chain.predict(summary=existing_summary, new_lines=new_lines)
 
+    async def apredict_new_summary(
+        self, messages: List[BaseMessage], existing_summary: str
+    ) -> str:
+        new_lines = get_buffer_string(
+            messages,
+            human_prefix=self.human_prefix,
+            ai_prefix=self.ai_prefix,
+        )
+
+        chain = LLMChain(llm=self.llm, prompt=self.prompt)
+        return await chain.apredict(summary=existing_summary, new_lines=new_lines)
+
 
 class ConversationSummaryMemory(BaseChatMemory, SummarizerMixin):
     """Conversation summarizer to chat memory."""
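The coroutine added above mirrors predict_new_summary on SummarizerMixin. A small sketch of calling it directly through ConversationSummaryMemory, which mixes SummarizerMixin in (the message imports are assumed to come from langchain_core.messages; the contents are made up for illustration):

from langchain.memory import ConversationSummaryMemory
from langchain_core.messages import AIMessage, HumanMessage


async def roll_up(llm, existing_summary: str) -> str:
    memory = ConversationSummaryMemory(llm=llm)
    # Folds the given messages into the running summary via chain.apredict.
    return await memory.apredict_new_summary(
        [
            HumanMessage(content="Could you prune the buffer?"),
            AIMessage(content="Done; older turns were folded into the summary."),
        ],
        existing_summary,
    )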
@@ -19,6 +19,11 @@ class ConversationSummaryBufferMemory(BaseChatMemory, SummarizerMixin):
         """String buffer of memory."""
         return self.load_memory_variables({})[self.memory_key]
 
+    async def abuffer(self) -> Union[str, List[BaseMessage]]:
+        """Async memory buffer."""
+        memory_variables = await self.aload_memory_variables({})
+        return memory_variables[self.memory_key]
+
     @property
     def memory_variables(self) -> List[str]:
         """Will always return list of memory variables.
@@ -43,6 +48,22 @@ class ConversationSummaryBufferMemory(BaseChatMemory, SummarizerMixin):
             )
         return {self.memory_key: final_buffer}
 
+    async def aload_memory_variables(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
+        """Asynchronously return key-value pairs given the text input to the chain."""
+        buffer = await self.chat_memory.aget_messages()
+        if self.moving_summary_buffer != "":
+            first_messages: List[BaseMessage] = [
+                self.summary_message_cls(content=self.moving_summary_buffer)
+            ]
+            buffer = first_messages + buffer
+        if self.return_messages:
+            final_buffer: Any = buffer
+        else:
+            final_buffer = get_buffer_string(
+                buffer, human_prefix=self.human_prefix, ai_prefix=self.ai_prefix
+            )
+        return {self.memory_key: final_buffer}
+
     @root_validator()
     def validate_prompt_input_variables(cls, values: Dict) -> Dict:
         """Validate that prompt input variables are consistent."""
@@ -60,6 +81,13 @@ class ConversationSummaryBufferMemory(BaseChatMemory, SummarizerMixin):
         super().save_context(inputs, outputs)
         self.prune()
 
+    async def asave_context(
+        self, inputs: Dict[str, Any], outputs: Dict[str, str]
+    ) -> None:
+        """Asynchronously save context from this conversation to buffer."""
+        await super().asave_context(inputs, outputs)
+        await self.aprune()
+
     def prune(self) -> None:
         """Prune buffer if it exceeds max token limit"""
         buffer = self.chat_memory.messages
@@ -73,7 +101,25 @@ class ConversationSummaryBufferMemory(BaseChatMemory, SummarizerMixin):
                 pruned_memory, self.moving_summary_buffer
             )
 
+    async def aprune(self) -> None:
+        """Asynchronously prune buffer if it exceeds max token limit"""
+        buffer = self.chat_memory.messages
+        curr_buffer_length = self.llm.get_num_tokens_from_messages(buffer)
+        if curr_buffer_length > self.max_token_limit:
+            pruned_memory = []
+            while curr_buffer_length > self.max_token_limit:
+                pruned_memory.append(buffer.pop(0))
+                curr_buffer_length = self.llm.get_num_tokens_from_messages(buffer)
+            self.moving_summary_buffer = await self.apredict_new_summary(
+                pruned_memory, self.moving_summary_buffer
+            )
+
     def clear(self) -> None:
         """Clear memory contents."""
         super().clear()
         self.moving_summary_buffer = ""
+
+    async def aclear(self) -> None:
+        """Asynchronously clear memory contents."""
+        await super().aclear()
+        self.moving_summary_buffer = ""
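aprune applies the same token-budget rule as prune but awaits apredict_new_summary for the roll-up, so asave_context can shrink the buffer without blocking. A sketch of the behavior this enables, loosely mirroring the tests that follow (llm and the tiny token limit are illustrative):

from langchain.memory.summary_buffer import ConversationSummaryBufferMemory


async def over_budget(llm) -> None:
    memory = ConversationSummaryBufferMemory(llm=llm, max_token_limit=5)
    await memory.asave_context({"input": "bar"}, {"output": "foo"})
    await memory.asave_context({"input": "bar1"}, {"output": "foo1"})
    # Once the buffer exceeds max_token_limit, aprune pops the oldest messages
    # and folds them into moving_summary_buffer via apredict_new_summary.
    print(memory.moving_summary_buffer)
    print(await memory.abuffer())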
@@ -1,32 +0,0 @@
-"""Test memory functionality."""
-
-from langchain.memory.summary_buffer import ConversationSummaryBufferMemory
-from tests.unit_tests.llms.fake_llm import FakeLLM
-
-
-def test_summary_buffer_memory_no_buffer_yet() -> None:
-    """Test ConversationSummaryBufferMemory when no inputs put in buffer yet."""
-    memory = ConversationSummaryBufferMemory(llm=FakeLLM(), memory_key="baz")
-    output = memory.load_memory_variables({})
-    assert output == {"baz": ""}
-
-
-def test_summary_buffer_memory_buffer_only() -> None:
-    """Test ConversationSummaryBufferMemory when only buffer."""
-    memory = ConversationSummaryBufferMemory(llm=FakeLLM(), memory_key="baz")
-    memory.save_context({"input": "bar"}, {"output": "foo"})
-    assert memory.buffer == ["Human: bar\nAI: foo"]
-    output = memory.load_memory_variables({})
-    assert output == {"baz": "Human: bar\nAI: foo"}
-
-
-def test_summary_buffer_memory_summary() -> None:
-    """Test ConversationSummaryBufferMemory when only buffer."""
-    memory = ConversationSummaryBufferMemory(
-        llm=FakeLLM(), memory_key="baz", max_token_limit=13
-    )
-    memory.save_context({"input": "bar"}, {"output": "foo"})
-    memory.save_context({"input": "bar1"}, {"output": "foo1"})
-    assert memory.buffer == ["Human: bar1\nAI: foo1"]
-    output = memory.load_memory_variables({})
-    assert output == {"baz": "foo\nHuman: bar1\nAI: foo1"}
@@ -0,0 +1,62 @@
+"""Test memory functionality."""
+
+from langchain.memory.summary_buffer import ConversationSummaryBufferMemory
+from tests.unit_tests.llms.fake_llm import FakeLLM
+
+
+def test_summary_buffer_memory_no_buffer_yet() -> None:
+    """Test ConversationSummaryBufferMemory when no inputs put in buffer yet."""
+    memory = ConversationSummaryBufferMemory(llm=FakeLLM(), memory_key="baz")
+    output = memory.load_memory_variables({})
+    assert output == {"baz": ""}
+
+
+async def test_summary_buffer_memory_no_buffer_yet_async() -> None:
+    """Test ConversationSummaryBufferMemory when no inputs put in buffer yet."""
+    memory = ConversationSummaryBufferMemory(llm=FakeLLM(), memory_key="baz")
+    output = await memory.aload_memory_variables({})
+    assert output == {"baz": ""}
+
+
+def test_summary_buffer_memory_buffer_only() -> None:
+    """Test ConversationSummaryBufferMemory when only buffer."""
+    memory = ConversationSummaryBufferMemory(llm=FakeLLM(), memory_key="baz")
+    memory.save_context({"input": "bar"}, {"output": "foo"})
+    assert memory.buffer == "Human: bar\nAI: foo"
+    output = memory.load_memory_variables({})
+    assert output == {"baz": "Human: bar\nAI: foo"}
+
+
+async def test_summary_buffer_memory_buffer_only_async() -> None:
+    """Test ConversationSummaryBufferMemory when only buffer."""
+    memory = ConversationSummaryBufferMemory(llm=FakeLLM(), memory_key="baz")
+    await memory.asave_context({"input": "bar"}, {"output": "foo"})
+    assert memory.buffer == "Human: bar\nAI: foo"
+    output = await memory.aload_memory_variables({})
+    assert output == {"baz": "Human: bar\nAI: foo"}
+
+
+def test_summary_buffer_memory_summary() -> None:
+    """Test ConversationSummaryBufferMemory when only buffer."""
+    llm = FakeLLM(queries={0: "summary"}, sequential_responses=True)
+    memory = ConversationSummaryBufferMemory(
+        llm=llm, memory_key="baz", max_token_limit=5
+    )
+    memory.save_context({"input": "bar"}, {"output": "foo"})
+    memory.save_context({"input": "bar1"}, {"output": "foo1"})
+    assert memory.buffer == "System: summary\nHuman: bar1\nAI: foo1"
+    output = memory.load_memory_variables({})
+    assert output == {"baz": "System: summary\nHuman: bar1\nAI: foo1"}
+
+
+async def test_summary_buffer_memory_summary_async() -> None:
+    """Test ConversationSummaryBufferMemory when only buffer."""
+    llm = FakeLLM(queries={0: "summary"}, sequential_responses=True)
+    memory = ConversationSummaryBufferMemory(
+        llm=llm, memory_key="baz", max_token_limit=5
+    )
+    await memory.asave_context({"input": "bar"}, {"output": "foo"})
+    await memory.asave_context({"input": "bar1"}, {"output": "foo1"})
+    assert memory.buffer == "System: summary\nHuman: bar1\nAI: foo1"
+    output = await memory.aload_memory_variables({})
+    assert output == {"baz": "System: summary\nHuman: bar1\nAI: foo1"}
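A note on the coroutine tests above: as bare async def functions they rely on pytest's async support (presumably pytest-asyncio or a similar plugin in auto mode; this commit does not show the test runner configuration). Outside pytest, any of them can be driven by hand:

import asyncio

# Runs one of the async tests defined above directly.
asyncio.run(test_summary_buffer_memory_summary_async())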