mirror of
https://github.com/hwchase17/langchain.git
synced 2025-08-08 04:25:46 +00:00
langchain[patch]: Add async methods to VectorStoreRetrieverMemory (#19408)
This commit is contained in:
parent
ef6d3d66d6
commit
1b813fe6fe
@ -39,6 +39,16 @@ class VectorStoreRetrieverMemory(BaseMemory):
|
||||
return get_prompt_input_key(inputs, self.memory_variables)
|
||||
return self.input_key
|
||||
|
||||
def _documents_to_memory_variables(
|
||||
self, docs: List[Document]
|
||||
) -> Dict[str, Union[List[Document], str]]:
|
||||
result: Union[List[Document], str]
|
||||
if not self.return_docs:
|
||||
result = "\n".join([doc.page_content for doc in docs])
|
||||
else:
|
||||
result = docs
|
||||
return {self.memory_key: result}
|
||||
|
||||
def load_memory_variables(
|
||||
self, inputs: Dict[str, Any]
|
||||
) -> Dict[str, Union[List[Document], str]]:
|
||||
@ -46,12 +56,16 @@ class VectorStoreRetrieverMemory(BaseMemory):
|
||||
input_key = self._get_prompt_input_key(inputs)
|
||||
query = inputs[input_key]
|
||||
docs = self.retriever.get_relevant_documents(query)
|
||||
result: Union[List[Document], str]
|
||||
if not self.return_docs:
|
||||
result = "\n".join([doc.page_content for doc in docs])
|
||||
else:
|
||||
result = docs
|
||||
return {self.memory_key: result}
|
||||
return self._documents_to_memory_variables(docs)
|
||||
|
||||
async def aload_memory_variables(
|
||||
self, inputs: Dict[str, Any]
|
||||
) -> Dict[str, Union[List[Document], str]]:
|
||||
"""Return history buffer."""
|
||||
input_key = self._get_prompt_input_key(inputs)
|
||||
query = inputs[input_key]
|
||||
docs = await self.retriever.aget_relevant_documents(query)
|
||||
return self._documents_to_memory_variables(docs)
|
||||
|
||||
def _form_documents(
|
||||
self, inputs: Dict[str, Any], outputs: Dict[str, str]
|
||||
@ -73,5 +87,15 @@ class VectorStoreRetrieverMemory(BaseMemory):
|
||||
documents = self._form_documents(inputs, outputs)
|
||||
self.retriever.add_documents(documents)
|
||||
|
||||
async def asave_context(
|
||||
self, inputs: Dict[str, Any], outputs: Dict[str, str]
|
||||
) -> None:
|
||||
"""Save context from this conversation to buffer."""
|
||||
documents = self._form_documents(inputs, outputs)
|
||||
await self.retriever.aadd_documents(documents)
|
||||
|
||||
def clear(self) -> None:
|
||||
"""Nothing to clear."""
|
||||
|
||||
async def aclear(self) -> None:
|
||||
"""Nothing to clear."""
|
||||
|
Loading…
Reference in New Issue
Block a user