mv chat history (#9725)

This commit is contained in:
Bagatur 2023-08-29 21:41:32 -07:00 committed by GitHub
parent d762a6b51f
commit 2d2b097fab
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 70 additions and 64 deletions

View File

@ -1,8 +1,9 @@
"""**Schemas** are the LangChain Base Classes and Interfaces.""" """**Schemas** are the LangChain Base Classes and Interfaces."""
from langchain.schema.agent import AgentAction, AgentFinish from langchain.schema.agent import AgentAction, AgentFinish
from langchain.schema.chat_history import BaseChatMessageHistory
from langchain.schema.document import BaseDocumentTransformer, Document from langchain.schema.document import BaseDocumentTransformer, Document
from langchain.schema.exceptions import LangChainException from langchain.schema.exceptions import LangChainException
from langchain.schema.memory import BaseChatMessageHistory, BaseMemory from langchain.schema.memory import BaseMemory
from langchain.schema.messages import ( from langchain.schema.messages import (
AIMessage, AIMessage,
BaseMessage, BaseMessage,
@ -40,10 +41,10 @@ Memory = BaseMemory
__all__ = [ __all__ = [
"BaseMemory", "BaseMemory",
"BaseStore", "BaseStore",
"BaseChatMessageHistory",
"AgentFinish", "AgentFinish",
"AgentAction", "AgentAction",
"Document", "Document",
"BaseChatMessageHistory",
"BaseDocumentTransformer", "BaseDocumentTransformer",
"BaseMessage", "BaseMessage",
"ChatMessage", "ChatMessage",

View File

@ -0,0 +1,67 @@
from __future__ import annotations
from abc import ABC, abstractmethod
from typing import List
from langchain.schema.messages import AIMessage, BaseMessage, HumanMessage
class BaseChatMessageHistory(ABC):
"""Abstract base class for storing chat message history.
See `ChatMessageHistory` for default implementation.
Example:
.. code-block:: python
class FileChatMessageHistory(BaseChatMessageHistory):
storage_path: str
session_id: str
@property
def messages(self):
with open(os.path.join(storage_path, session_id), 'r:utf-8') as f:
messages = json.loads(f.read())
return messages_from_dict(messages)
def add_message(self, message: BaseMessage) -> None:
messages = self.messages.append(_message_to_dict(message))
with open(os.path.join(storage_path, session_id), 'w') as f:
json.dump(f, messages)
def clear(self):
with open(os.path.join(storage_path, session_id), 'w') as f:
f.write("[]")
"""
messages: List[BaseMessage]
"""A list of Messages stored in-memory."""
def add_user_message(self, message: str) -> None:
"""Convenience method for adding a human message string to the store.
Args:
message: The string contents of a human message.
"""
self.add_message(HumanMessage(content=message))
def add_ai_message(self, message: str) -> None:
"""Convenience method for adding an AI message string to the store.
Args:
message: The string contents of an AI message.
"""
self.add_message(AIMessage(content=message))
@abstractmethod
def add_message(self, message: BaseMessage) -> None:
"""Add a Message object to the store.
Args:
message: A BaseMessage object to store.
"""
raise NotImplementedError()
@abstractmethod
def clear(self) -> None:
"""Remove all messages from the store"""

View File

@ -4,7 +4,6 @@ from abc import ABC, abstractmethod
from typing import Any, Dict, List from typing import Any, Dict, List
from langchain.load.serializable import Serializable from langchain.load.serializable import Serializable
from langchain.schema.messages import AIMessage, BaseMessage, HumanMessage
class BaseMemory(Serializable, ABC): class BaseMemory(Serializable, ABC):
@ -58,64 +57,3 @@ class BaseMemory(Serializable, ABC):
@abstractmethod @abstractmethod
def clear(self) -> None: def clear(self) -> None:
"""Clear memory contents.""" """Clear memory contents."""
class BaseChatMessageHistory(ABC):
"""Abstract base class for storing chat message history.
See `ChatMessageHistory` for default implementation.
Example:
.. code-block:: python
class FileChatMessageHistory(BaseChatMessageHistory):
storage_path: str
session_id: str
@property
def messages(self):
with open(os.path.join(storage_path, session_id), 'r:utf-8') as f:
messages = json.loads(f.read())
return messages_from_dict(messages)
def add_message(self, message: BaseMessage) -> None:
messages = self.messages.append(_message_to_dict(message))
with open(os.path.join(storage_path, session_id), 'w') as f:
json.dump(f, messages)
def clear(self):
with open(os.path.join(storage_path, session_id), 'w') as f:
f.write("[]")
"""
messages: List[BaseMessage]
"""A list of Messages stored in-memory."""
def add_user_message(self, message: str) -> None:
"""Convenience method for adding a human message string to the store.
Args:
message: The string contents of a human message.
"""
self.add_message(HumanMessage(content=message))
def add_ai_message(self, message: str) -> None:
"""Convenience method for adding an AI message string to the store.
Args:
message: The string contents of an AI message.
"""
self.add_message(AIMessage(content=message))
@abstractmethod
def add_message(self, message: BaseMessage) -> None:
"""Add a Message object to the store.
Args:
message: A BaseMessage object to store.
"""
raise NotImplementedError()
@abstractmethod
def clear(self) -> None:
"""Remove all messages from the store"""