Compare commits

...

3 Commits

Author SHA1 Message Date
Bagatur
dc3c51acb0 fmt 2023-08-11 10:38:51 -07:00
Bagatur
8e91e5a18b fmt 2023-08-11 10:37:54 -07:00
Bagatur
20e7d8c463 rfc 2023-08-11 10:36:08 -07:00
2 changed files with 36 additions and 3 deletions

View File

@@ -1,12 +1,14 @@
from __future__ import annotations
from typing import Any, Dict, Optional
from typing import Any, Dict, List, Optional
from langchain.memory import ConversationBufferMemory
from langchain.memory.chat_memory import BaseChatMemory
from langchain.memory.chat_message_histories import ZepChatMessageHistory
from langchain.memory.utils import get_prompt_input_key
class ZepMemory(ConversationBufferMemory):
class _ZepMemory(BaseChatMemory):
"""Persist your chain history to the Zep Memory Server.
The number of messages returned by Zep and when the Zep server summarizes chat
@@ -51,6 +53,7 @@ class ZepMemory(ConversationBufferMemory):
"""
chat_memory: ZepChatMessageHistory
memory_key: str = "history" #: :meta private:
def __init__(
self,
@@ -102,6 +105,14 @@ class ZepMemory(ConversationBufferMemory):
memory_key=memory_key,
)
@property
def memory_variables(self) -> List[str]:
"""Will always return list of memory variables.
:meta private:
"""
return [self.memory_key]
def save_context(
self,
inputs: Dict[str, Any],
@@ -122,3 +133,25 @@ class ZepMemory(ConversationBufferMemory):
input_str, output_str = self._get_input_output(inputs, outputs)
self.chat_memory.add_user_message(input_str, metadata=metadata)
self.chat_memory.add_ai_message(output_str, metadata=metadata)
class ZepSearchMemory(_ZepMemory):
top_k: int = 4
def _get_prompt_input_key(self, inputs: Dict[str, Any]) -> str:
"""Get the input key for the prompt."""
if self.input_key is None:
return get_prompt_input_key(inputs, self.memory_variables)
return self.input_key
def load_memory_variables(self, inputs: Dict[str, Any]) -> Dict[str, str]:
"""Return history buffer."""
input_key = self._get_prompt_input_key(inputs)
query = inputs[input_key]
results = self.chat_memory.search(query, limit=self.top_k)
result = "\n".join([r.message.pop("content") for r in results])
return {self.memory_key: result}
class ZepBufferMemory(_ZepMemory, ConversationBufferMemory):
""""""

View File

@@ -42,7 +42,7 @@ class ZepRetriever(BaseRetriever):
try:
from zep_python import ZepClient
except ImportError:
raise ValueError(
raise ImportError(
"Could not import zep-python package. "
"Please install it with `pip install zep-python`."
)