Compare commits

...

5 Commits

Author SHA1 Message Date
Eugene Yurtsev
354c382875 x 2023-12-05 11:44:05 -05:00
Eugene Yurtsev
05d53073cf x 2023-12-05 11:43:43 -05:00
Eugene Yurtsev
11f52cbee5 Merge branch 'master' into eugene/update_file_chat_memory 2023-12-05 11:20:03 -05:00
Eugene Yurtsev
ff023e96dd x 2023-12-05 10:36:42 -05:00
Eugene Yurtsev
78ae25d6b7 x 2023-12-04 22:58:32 -05:00
2 changed files with 101 additions and 10 deletions

View File

@@ -1,7 +1,12 @@
"""Chat message history that stores history in a local file.
This chat history is mainly useful for testing / prototyping purposes.
"""
import json
import logging
import re
from pathlib import Path
from typing import List
from typing import Callable, List, Type, TypeVar, Union
from langchain_core.chat_history import BaseChatMessageHistory
from langchain_core.messages import (
@@ -12,20 +17,48 @@ from langchain_core.messages import (
logger = logging.getLogger(__name__)
Self = TypeVar("Self", bound="FileChatMessageHistory")
def _is_valid_session_id(session_id: str) -> bool:
"""Check if the session ID is in a valid format."""
# Use a regular expression to match the allowed characters
valid_characters = re.compile(r"^[a-zA-Z0-9-_]+$")
return bool(valid_characters.match(session_id))
class FileChatMessageHistory(BaseChatMessageHistory):
"""
Chat message history that stores history in a local file.
"""Chat message history that stores history in a local file.
Args:
file_path: path of the local file to store the messages.
Examples:
.. code-block:: python
from langchain_core.messages import HumanMessage
from langchain.memory import FileChatMessageHistory
history = FileChatMessageHistory("history.json")
history.add_message(HumanMessage("Hello, world!"))
history.messages # [HumanMessage("Hello, world!")]
history.clear()
history.messages # []
"""
def __init__(self, file_path: str):
    """Bind the history to a local file, creating it when it is absent.

    Args:
        file_path: path of the local file to store the messages.
    """
    target = Path(file_path)
    self.file_path = target
    if target.exists():
        return
    # Seed a brand-new file with an empty JSON list.
    target.touch()
    target.write_text(json.dumps([]))
def __init__(self, file_path: Union[str, Path]) -> None:
    """Bind the history to a local file, creating it when it is absent.

    Args:
        file_path: location of the file that stores the messages; a
            string is converted to a ``Path``, a ``Path`` is used as is.
    """
    path = Path(file_path) if isinstance(file_path, str) else file_path
    if not path.exists():
        # Seed a brand-new file with an empty JSON list.
        path.touch()
        path.write_text(json.dumps([]))
    self.file_path = path
@property
def messages(self) -> List[BaseMessage]: # type: ignore
@@ -43,3 +76,31 @@ class FileChatMessageHistory(BaseChatMessageHistory):
def clear(self) -> None:
    """Overwrite the backing file with an empty JSON list."""
    empty_history = json.dumps([])
    self.file_path.write_text(empty_history)
@classmethod
def create_session_factory(
    cls: Type[Self], base_dir: Union[str, Path]
) -> Callable[[str], Self]:
    """Build a callable that maps a session ID to a file-backed history.

    Each session is stored in ``<base_dir>/<session_id>.json``.

    Args:
        base_dir: Base directory to use for storing the chat histories.

    Returns:
        A function that takes a session ID and returns an instance of
        this class bound to that session's file.
    """
    root = Path(base_dir) if isinstance(base_dir, str) else base_dir

    def get_chat_history(session_id: str) -> Self:
        """Return the chat history for ``session_id``.

        Raises:
            ValueError: if the session ID fails format validation.
        """
        if _is_valid_session_id(session_id):
            return cls(root / f"{session_id}.json")
        raise ValueError(
            f"Session ID {session_id} is not in a valid format. "
            "Session ID must only contain alphanumeric characters, "
            "hyphens, and underscores."
        )

    return get_chat_history

View File

@@ -16,6 +16,13 @@ def file_chat_message_history() -> Generator[FileChatMessageHistory, None, None]
yield file_chat_message_history
@pytest.fixture()
def base_dir() -> Generator[Path, None, None]:
    """Provide a temporary directory as a ``Path``; removed on teardown."""
    with tempfile.TemporaryDirectory() as tmp_name:
        tmp_path = Path(tmp_name)
        yield tmp_path
def test_add_messages(file_chat_message_history: FileChatMessageHistory) -> None:
file_chat_message_history.add_user_message("Hello!")
file_chat_message_history.add_ai_message("Hi there!")
@@ -69,3 +76,26 @@ def test_multiple_sessions(file_chat_message_history: FileChatMessageHistory) ->
assert messages[2].content == "Tell me a joke."
expected_content = "Why did the chicken cross the road? To get to the other side!"
assert messages[3].content == expected_content
def test_session_factory(base_dir: Path) -> None:
    """Sessions from one factory are isolated and stored as separate files."""
    make_history = FileChatMessageHistory.create_session_factory(base_dir)

    # First session: starts empty, then holds exactly what was added.
    first = make_history("session_1")
    assert first.messages == []
    first.add_message(HumanMessage(content="Hello!"))
    assert first.messages == [HumanMessage(content="Hello!")]

    # Second session: independent of the first.
    second = make_history("session_2")
    assert second.messages == []
    for text in ("Goodbye!", "Meow!"):
        second.add_message(HumanMessage(content=text))
    expected_second = [
        HumanMessage(content="Goodbye!"),
        HumanMessage(content="Meow!"),
    ]
    assert second.messages == expected_second

    # Make sure that session 1 is not affected
    assert first.messages == [HumanMessage(content="Hello!")]

    # One JSON file per session, named after the session ID.
    stored_files = sorted(path.name for path in base_dir.glob("*.json"))
    assert stored_files == ["session_1.json", "session_2.json"]