langchain[patch]: add async methods to ConversationSummaryBufferMemory (#20956)

Added asynchronously callable methods, following the ConversationSummaryBufferMemory API documentation.

Co-authored-by: Bagatur <22008038+baskaryan@users.noreply.github.com>
Co-authored-by: Chester Curme <chester.curme@gmail.com>
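For orientation, the additions give the summary-buffer memory an async twin for every sync entry point: abuffer, aload_memory_variables, asave_context, aprune, and aclear, plus apredict_new_summary on the shared mixin. A minimal usage sketch, assuming the repo's FakeLLM test helper as a stand-in for a real model:

import asyncio

from langchain.memory.summary_buffer import ConversationSummaryBufferMemory
from tests.unit_tests.llms.fake_llm import FakeLLM  # test helper; any LLM works


async def main() -> None:
    memory = ConversationSummaryBufferMemory(
        llm=FakeLLM(), memory_key="history", max_token_limit=40
    )
    # Write path: persists the turn, then prunes asynchronously.
    await memory.asave_context({"input": "hi"}, {"output": "hello"})
    # Read paths: both resolve through aload_memory_variables.
    print(await memory.aload_memory_variables({}))
    print(await memory.abuffer())
    await memory.aclear()


asyncio.run(main())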
parent cecd875cdc
commit 40c02cedaf
libs/langchain/langchain/memory/summary.py

@@ -34,6 +34,18 @@ class SummarizerMixin(BaseModel):
         chain = LLMChain(llm=self.llm, prompt=self.prompt)
         return chain.predict(summary=existing_summary, new_lines=new_lines)
 
+    async def apredict_new_summary(
+        self, messages: List[BaseMessage], existing_summary: str
+    ) -> str:
+        new_lines = get_buffer_string(
+            messages,
+            human_prefix=self.human_prefix,
+            ai_prefix=self.ai_prefix,
+        )
+
+        chain = LLMChain(llm=self.llm, prompt=self.prompt)
+        return await chain.apredict(summary=existing_summary, new_lines=new_lines)
+
 
 class ConversationSummaryMemory(BaseChatMemory, SummarizerMixin):
     """Conversation summarizer to chat memory."""
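apredict_new_summary mirrors predict_new_summary exactly, swapping chain.predict for await chain.apredict. Since it lives on SummarizerMixin, both ConversationSummaryMemory and ConversationSummaryBufferMemory inherit it. A small sketch, again with the FakeLLM test helper standing in for a real model:

import asyncio

from langchain.memory.summary import ConversationSummaryMemory
from langchain_core.messages import AIMessage, HumanMessage
from tests.unit_tests.llms.fake_llm import FakeLLM  # test helper stand-in

memory = ConversationSummaryMemory(llm=FakeLLM())
# Fold a couple of turns into a fresh summary string.
new_summary = asyncio.run(
    memory.apredict_new_summary(
        messages=[HumanMessage(content="hi"), AIMessage(content="hello")],
        existing_summary="",
    )
)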
libs/langchain/langchain/memory/summary_buffer.py

@@ -19,6 +19,11 @@ class ConversationSummaryBufferMemory(BaseChatMemory, SummarizerMixin):
         """String buffer of memory."""
         return self.load_memory_variables({})[self.memory_key]
 
+    async def abuffer(self) -> Union[str, List[BaseMessage]]:
+        """Async memory buffer."""
+        memory_variables = await self.aload_memory_variables({})
+        return memory_variables[self.memory_key]
+
     @property
     def memory_variables(self) -> List[str]:
         """Will always return list of memory variables.
@@ -43,6 +48,22 @@ class ConversationSummaryBufferMemory(BaseChatMemory, SummarizerMixin):
         )
         return {self.memory_key: final_buffer}
 
+    async def aload_memory_variables(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
+        """Asynchronously return key-value pairs given the text input to the chain."""
+        buffer = await self.chat_memory.aget_messages()
+        if self.moving_summary_buffer != "":
+            first_messages: List[BaseMessage] = [
+                self.summary_message_cls(content=self.moving_summary_buffer)
+            ]
+            buffer = first_messages + buffer
+        if self.return_messages:
+            final_buffer: Any = buffer
+        else:
+            final_buffer = get_buffer_string(
+                buffer, human_prefix=self.human_prefix, ai_prefix=self.ai_prefix
+            )
+        return {self.memory_key: final_buffer}
+
     @root_validator()
     def validate_prompt_input_variables(cls, values: Dict) -> Dict:
         """Validate that prompt input variables are consistent."""
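aload_memory_variables reproduces the sync semantics: a non-empty moving_summary_buffer is prepended as a summary_message_cls (a SystemMessage by default), and return_messages controls whether callers get message objects or a flattened string. A sketch of the message-object shape, with the summary set by hand (normally (a)prune maintains it):

import asyncio

from langchain.memory.summary_buffer import ConversationSummaryBufferMemory
from tests.unit_tests.llms.fake_llm import FakeLLM  # test helper stand-in

memory = ConversationSummaryBufferMemory(
    llm=FakeLLM(), memory_key="history", return_messages=True
)
memory.moving_summary_buffer = "summary so far"  # normally set by (a)prune
result = asyncio.run(memory.aload_memory_variables({}))
# result["history"][0] is memory.summary_message_cls(content="summary so far"),
# i.e. a SystemMessage by default, followed by the buffered messages.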
@@ -60,6 +81,13 @@ class ConversationSummaryBufferMemory(BaseChatMemory, SummarizerMixin):
         super().save_context(inputs, outputs)
         self.prune()
 
+    async def asave_context(
+        self, inputs: Dict[str, Any], outputs: Dict[str, str]
+    ) -> None:
+        """Asynchronously save context from this conversation to buffer."""
+        await super().asave_context(inputs, outputs)
+        await self.aprune()
+
     def prune(self) -> None:
         """Prune buffer if it exceeds max token limit"""
         buffer = self.chat_memory.messages
@@ -73,7 +101,25 @@ class ConversationSummaryBufferMemory(BaseChatMemory, SummarizerMixin):
                 pruned_memory, self.moving_summary_buffer
             )
 
+    async def aprune(self) -> None:
+        """Asynchronously prune buffer if it exceeds max token limit"""
+        buffer = self.chat_memory.messages
+        curr_buffer_length = self.llm.get_num_tokens_from_messages(buffer)
+        if curr_buffer_length > self.max_token_limit:
+            pruned_memory = []
+            while curr_buffer_length > self.max_token_limit:
+                pruned_memory.append(buffer.pop(0))
+                curr_buffer_length = self.llm.get_num_tokens_from_messages(buffer)
+            self.moving_summary_buffer = await self.apredict_new_summary(
+                pruned_memory, self.moving_summary_buffer
+            )
+
     def clear(self) -> None:
         """Clear memory contents."""
         super().clear()
         self.moving_summary_buffer = ""
+
+    async def aclear(self) -> None:
+        """Asynchronously clear memory contents."""
+        await super().aclear()
+        self.moving_summary_buffer = ""
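aprune is a line-for-line async twin of prune: pop the oldest messages until the buffer fits max_token_limit, then fold them into moving_summary_buffer via apredict_new_summary. A sketch of the observable effect, mirroring the new unit test's FakeLLM setup with a canned "summary" response:

import asyncio

from langchain.memory.summary_buffer import ConversationSummaryBufferMemory
from tests.unit_tests.llms.fake_llm import FakeLLM  # test helper stand-in


async def demo() -> None:
    llm = FakeLLM(queries={0: "summary"}, sequential_responses=True)
    memory = ConversationSummaryBufferMemory(llm=llm, max_token_limit=5)
    await memory.asave_context({"input": "bar"}, {"output": "foo"})
    # The second turn pushes past the limit, so aprune folds the first
    # turn into the moving summary.
    await memory.asave_context({"input": "bar1"}, {"output": "foo1"})
    assert memory.moving_summary_buffer == "summary"


asyncio.run(demo())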
Old unit tests (file contents removed):

@@ -1,32 +0,0 @@
-"""Test memory functionality."""
-
-from langchain.memory.summary_buffer import ConversationSummaryBufferMemory
-from tests.unit_tests.llms.fake_llm import FakeLLM
-
-
-def test_summary_buffer_memory_no_buffer_yet() -> None:
-    """Test ConversationSummaryBufferMemory when no inputs put in buffer yet."""
-    memory = ConversationSummaryBufferMemory(llm=FakeLLM(), memory_key="baz")
-    output = memory.load_memory_variables({})
-    assert output == {"baz": ""}
-
-
-def test_summary_buffer_memory_buffer_only() -> None:
-    """Test ConversationSummaryBufferMemory when only buffer."""
-    memory = ConversationSummaryBufferMemory(llm=FakeLLM(), memory_key="baz")
-    memory.save_context({"input": "bar"}, {"output": "foo"})
-    assert memory.buffer == ["Human: bar\nAI: foo"]
-    output = memory.load_memory_variables({})
-    assert output == {"baz": "Human: bar\nAI: foo"}
-
-
-def test_summary_buffer_memory_summary() -> None:
-    """Test ConversationSummaryBufferMemory when only buffer."""
-    memory = ConversationSummaryBufferMemory(
-        llm=FakeLLM(), memory_key="baz", max_token_limit=13
-    )
-    memory.save_context({"input": "bar"}, {"output": "foo"})
-    memory.save_context({"input": "bar1"}, {"output": "foo1"})
-    assert memory.buffer == ["Human: bar1\nAI: foo1"]
-    output = memory.load_memory_variables({})
-    assert output == {"baz": "foo\nHuman: bar1\nAI: foo1"}
New unit tests (file added, with async counterparts for each case):

@@ -0,0 +1,62 @@
+"""Test memory functionality."""
+
+from langchain.memory.summary_buffer import ConversationSummaryBufferMemory
+from tests.unit_tests.llms.fake_llm import FakeLLM
+
+
+def test_summary_buffer_memory_no_buffer_yet() -> None:
+    """Test ConversationSummaryBufferMemory when no inputs put in buffer yet."""
+    memory = ConversationSummaryBufferMemory(llm=FakeLLM(), memory_key="baz")
+    output = memory.load_memory_variables({})
+    assert output == {"baz": ""}
+
+
+async def test_summary_buffer_memory_no_buffer_yet_async() -> None:
+    """Test ConversationSummaryBufferMemory when no inputs put in buffer yet."""
+    memory = ConversationSummaryBufferMemory(llm=FakeLLM(), memory_key="baz")
+    output = await memory.aload_memory_variables({})
+    assert output == {"baz": ""}
+
+
+def test_summary_buffer_memory_buffer_only() -> None:
+    """Test ConversationSummaryBufferMemory when only buffer."""
+    memory = ConversationSummaryBufferMemory(llm=FakeLLM(), memory_key="baz")
+    memory.save_context({"input": "bar"}, {"output": "foo"})
+    assert memory.buffer == "Human: bar\nAI: foo"
+    output = memory.load_memory_variables({})
+    assert output == {"baz": "Human: bar\nAI: foo"}
+
+
+async def test_summary_buffer_memory_buffer_only_async() -> None:
+    """Test ConversationSummaryBufferMemory when only buffer."""
+    memory = ConversationSummaryBufferMemory(llm=FakeLLM(), memory_key="baz")
+    await memory.asave_context({"input": "bar"}, {"output": "foo"})
+    assert memory.buffer == "Human: bar\nAI: foo"
+    output = await memory.aload_memory_variables({})
+    assert output == {"baz": "Human: bar\nAI: foo"}
+
+
+def test_summary_buffer_memory_summary() -> None:
+    """Test ConversationSummaryBufferMemory when only buffer."""
+    llm = FakeLLM(queries={0: "summary"}, sequential_responses=True)
+    memory = ConversationSummaryBufferMemory(
+        llm=llm, memory_key="baz", max_token_limit=5
+    )
+    memory.save_context({"input": "bar"}, {"output": "foo"})
+    memory.save_context({"input": "bar1"}, {"output": "foo1"})
+    assert memory.buffer == "System: summary\nHuman: bar1\nAI: foo1"
+    output = memory.load_memory_variables({})
+    assert output == {"baz": "System: summary\nHuman: bar1\nAI: foo1"}
+
+
+async def test_summary_buffer_memory_summary_async() -> None:
+    """Test ConversationSummaryBufferMemory when only buffer."""
+    llm = FakeLLM(queries={0: "summary"}, sequential_responses=True)
+    memory = ConversationSummaryBufferMemory(
+        llm=llm, memory_key="baz", max_token_limit=5
+    )
+    await memory.asave_context({"input": "bar"}, {"output": "foo"})
+    await memory.asave_context({"input": "bar1"}, {"output": "foo1"})
+    assert memory.buffer == "System: summary\nHuman: bar1\nAI: foo1"
+    output = await memory.aload_memory_variables({})
+    assert output == {"baz": "System: summary\nHuman: bar1\nAI: foo1"}
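The new async tests are bare async def functions with no decorator, which presumes the suite enables pytest's asyncio mode globally (e.g. pytest-asyncio in auto mode). Outside pytest they can be driven by hand; a sketch, with the module path assumed from the imports above:

import asyncio

# Path assumed; adjust to where the test module actually lives.
from tests.unit_tests.memory.test_summary_buffer import (
    test_summary_buffer_memory_summary_async,
)

asyncio.run(test_summary_buffer_memory_summary_async())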