langchain[patch]: Add async methods to VectorStoreRetrieverMemory (#19408)

This commit is contained in:
Christophe Bornet 2024-03-22 23:44:24 +01:00 committed by GitHub
parent ef6d3d66d6
commit 1b813fe6fe
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -39,6 +39,16 @@ class VectorStoreRetrieverMemory(BaseMemory):
return get_prompt_input_key(inputs, self.memory_variables)
return self.input_key
def _documents_to_memory_variables(
self, docs: List[Document]
) -> Dict[str, Union[List[Document], str]]:
result: Union[List[Document], str]
if not self.return_docs:
result = "\n".join([doc.page_content for doc in docs])
else:
result = docs
return {self.memory_key: result}
def load_memory_variables(
self, inputs: Dict[str, Any]
) -> Dict[str, Union[List[Document], str]]:
@ -46,12 +56,16 @@ class VectorStoreRetrieverMemory(BaseMemory):
input_key = self._get_prompt_input_key(inputs)
query = inputs[input_key]
docs = self.retriever.get_relevant_documents(query)
result: Union[List[Document], str]
if not self.return_docs:
result = "\n".join([doc.page_content for doc in docs])
else:
result = docs
return {self.memory_key: result}
return self._documents_to_memory_variables(docs)
async def aload_memory_variables(
self, inputs: Dict[str, Any]
) -> Dict[str, Union[List[Document], str]]:
"""Return history buffer."""
input_key = self._get_prompt_input_key(inputs)
query = inputs[input_key]
docs = await self.retriever.aget_relevant_documents(query)
return self._documents_to_memory_variables(docs)
def _form_documents(
self, inputs: Dict[str, Any], outputs: Dict[str, str]
@ -73,5 +87,15 @@ class VectorStoreRetrieverMemory(BaseMemory):
documents = self._form_documents(inputs, outputs)
self.retriever.add_documents(documents)
async def asave_context(
self, inputs: Dict[str, Any], outputs: Dict[str, str]
) -> None:
"""Save context from this conversation to buffer."""
documents = self._form_documents(inputs, outputs)
await self.retriever.aadd_documents(documents)
def clear(self) -> None:
"""Nothing to clear."""
async def aclear(self) -> None:
"""Nothing to clear."""