diff --git a/libs/community/langchain_community/chat_message_histories/file.py b/libs/community/langchain_community/chat_message_histories/file.py index d6f2f43c3d6..41dbd2afaad 100644 --- a/libs/community/langchain_community/chat_message_histories/file.py +++ b/libs/community/langchain_community/chat_message_histories/file.py @@ -1,45 +1,5 @@ -import json -import logging -from pathlib import Path -from typing import List +from langchain_core.chat_history import FileChatMessageHistory -from langchain_core.chat_history import BaseChatMessageHistory -from langchain_core.messages import ( - BaseMessage, - messages_from_dict, - messages_to_dict, -) - -logger = logging.getLogger(__name__) - - -class FileChatMessageHistory(BaseChatMessageHistory): - """ - Chat message history that stores history in a local file. - - Args: - file_path: path of the local file to store the messages. - """ - - def __init__(self, file_path: str): - self.file_path = Path(file_path) - if not self.file_path.exists(): - self.file_path.touch() - self.file_path.write_text(json.dumps([])) - - @property - def messages(self) -> List[BaseMessage]: # type: ignore - """Retrieve the messages from the local file""" - items = json.loads(self.file_path.read_text()) - messages = messages_from_dict(items) - return messages - - def add_message(self, message: BaseMessage) -> None: - """Append the message to the record in the local file""" - messages = messages_to_dict(self.messages) - messages.append(messages_to_dict([message])[0]) - self.file_path.write_text(json.dumps(messages)) - - def clear(self) -> None: - """Clear session memory from the local file""" - self.file_path.write_text(json.dumps([])) +__all__ = [ + "FileChatMessageHistory", +] diff --git a/libs/community/langchain_community/chat_message_histories/in_memory.py b/libs/community/langchain_community/chat_message_histories/in_memory.py index fe6c6406524..679c9ce665e 100644 --- a/libs/community/langchain_community/chat_message_histories/in_memory.py +++ b/libs/community/langchain_community/chat_message_histories/in_memory.py @@ -1,31 +1,5 @@ -from typing import List, Sequence +from langchain_core.chat_history import InMemoryChatMessageHistory as ChatMessageHistory -from langchain_core.chat_history import BaseChatMessageHistory -from langchain_core.messages import BaseMessage -from langchain_core.pydantic_v1 import BaseModel, Field - - -class ChatMessageHistory(BaseChatMessageHistory, BaseModel): - """In memory implementation of chat message history. - - Stores messages in an in memory list. - """ - - messages: List[BaseMessage] = Field(default_factory=list) - - async def aget_messages(self) -> List[BaseMessage]: - return self.messages - - def add_message(self, message: BaseMessage) -> None: - """Add a self-created message to the store""" - self.messages.append(message) - - async def aadd_messages(self, messages: Sequence[BaseMessage]) -> None: - """Add messages to the store""" - self.add_messages(messages) - - def clear(self) -> None: - self.messages = [] - - async def aclear(self) -> None: - self.clear() +__all__ = [ + "ChatMessageHistory", +] diff --git a/libs/core/langchain_core/chat_history.py b/libs/core/langchain_core/chat_history.py index 8f930745585..4388b373305 100644 --- a/libs/core/langchain_core/chat_history.py +++ b/libs/core/langchain_core/chat_history.py @@ -16,7 +16,9 @@ """ # noqa: E501 from __future__ import annotations +import json from abc import ABC, abstractmethod +from pathlib import Path from typing import List, Sequence, Union from langchain_core.messages import ( @@ -24,7 +26,10 @@ from langchain_core.messages import ( BaseMessage, HumanMessage, get_buffer_string, + messages_from_dict, + messages_to_dict, ) +from langchain_core.pydantic_v1 import BaseModel, Field from langchain_core.runnables import run_in_executor @@ -184,3 +189,61 @@ class BaseChatMessageHistory(ABC): def __str__(self) -> str: """Return a string representation of the chat history.""" return get_buffer_string(self.messages) + + +class InMemoryChatMessageHistory(BaseChatMessageHistory, BaseModel): + """In memory implementation of chat message history. + + Stores messages in an in memory list. + """ + + messages: List[BaseMessage] = Field(default_factory=list) + + async def aget_messages(self) -> List[BaseMessage]: + return self.messages + + def add_message(self, message: BaseMessage) -> None: + """Add a self-created message to the store""" + self.messages.append(message) + + async def aadd_messages(self, messages: Sequence[BaseMessage]) -> None: + """Add messages to the store""" + self.add_messages(messages) + + def clear(self) -> None: + self.messages = [] + + async def aclear(self) -> None: + self.clear() + + +class FileChatMessageHistory(BaseChatMessageHistory): + """Chat message history that stores history in a local file.""" + + def __init__(self, file_path: str) -> None: + """Initialize the file path for the chat history. + + Args: + file_path: The path to the local file to store the chat history. + """ + self.file_path = Path(file_path) + if not self.file_path.exists(): + self.file_path.touch() + self.file_path.write_text(json.dumps([])) + + @property + def messages(self) -> List[BaseMessage]: # type: ignore + """Retrieve the messages from the local file""" + items = json.loads(self.file_path.read_text()) + messages = messages_from_dict(items) + return messages + + def add_message(self, message: BaseMessage) -> None: + """Append the message to the record in the local file""" + messages = messages_to_dict(self.messages) + messages.append(messages_to_dict([message])[0]) + self.file_path.write_text(json.dumps(messages)) + + def clear(self) -> None: + """Clear session memory from the local file""" + self.file_path.write_text(json.dumps([])) diff --git a/libs/community/tests/unit_tests/chat_message_histories/test_file.py b/libs/core/tests/unit_tests/chat_history/test_file_chat_message_history.py similarity index 97% rename from libs/community/tests/unit_tests/chat_message_histories/test_file.py rename to libs/core/tests/unit_tests/chat_history/test_file_chat_message_history.py index f069ff24935..4c292c61e5a 100644 --- a/libs/community/tests/unit_tests/chat_message_histories/test_file.py +++ b/libs/core/tests/unit_tests/chat_history/test_file_chat_message_history.py @@ -3,9 +3,9 @@ from pathlib import Path from typing import Generator import pytest -from langchain_core.messages import AIMessage, HumanMessage -from langchain_community.chat_message_histories import FileChatMessageHistory +from langchain_core.chat_history import FileChatMessageHistory +from langchain_core.messages import AIMessage, HumanMessage @pytest.fixture diff --git a/libs/langchain/langchain/memory/chat_memory.py b/libs/langchain/langchain/memory/chat_memory.py index 671edf9f31b..10feaa3e1b9 100644 --- a/libs/langchain/langchain/memory/chat_memory.py +++ b/libs/langchain/langchain/memory/chat_memory.py @@ -2,8 +2,10 @@ import warnings from abc import ABC from typing import Any, Dict, Optional, Tuple -from langchain_community.chat_message_histories.in_memory import ChatMessageHistory -from langchain_core.chat_history import BaseChatMessageHistory +from langchain_core.chat_history import ( + BaseChatMessageHistory, + InMemoryChatMessageHistory, +) from langchain_core.memory import BaseMemory from langchain_core.messages import AIMessage, HumanMessage from langchain_core.pydantic_v1 import Field @@ -14,7 +16,9 @@ from langchain.memory.utils import get_prompt_input_key class BaseChatMemory(BaseMemory, ABC): """Abstract base class for chat memory.""" - chat_memory: BaseChatMessageHistory = Field(default_factory=ChatMessageHistory) + chat_memory: BaseChatMessageHistory = Field( + default_factory=InMemoryChatMessageHistory + ) output_key: Optional[str] = None input_key: Optional[str] = None return_messages: bool = False