diff --git a/docs/extras/integrations/memory/sql_chat_message_history.ipynb b/docs/extras/integrations/memory/sql_chat_message_history.ipynb new file mode 100644 index 00000000000..54a8d85b3e9 --- /dev/null +++ b/docs/extras/integrations/memory/sql_chat_message_history.ipynb @@ -0,0 +1,235 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "source": [ + "# SQL Chat Message History\n", + "\n", + "This notebook goes over the **SQLChatMessageHistory** class, which allows you to store chat history in any database supported by SQLAlchemy.\n", + "\n", + "Please note that to use it with databases other than SQLite, you will need to install the corresponding database driver." + ], + "metadata": { + "collapsed": false + }, + "id": "f22eab3f84cbeb37" + }, + { + "cell_type": "markdown", + "source": [ + "### Basic Usage\n", + "\n", + "To use the storage you need to provide only 2 things:\n", + "\n", + "1. Session Id - a unique identifier of the session, like user name, email, chat id etc.\n", + "2. Connection string - a string that specifies the database connection. It will be passed to the SQLAlchemy `create_engine` function." 
+ ], + "metadata": { + "collapsed": false + }, + "id": "f8f2830ee9ca1e01" + }, + { + "cell_type": "code", + "execution_count": 1, + "outputs": [], + "source": [ + "from langchain.memory.chat_message_histories import SQLChatMessageHistory\n", + "\n", + "chat_message_history = SQLChatMessageHistory(\n", + "\tsession_id='test_session',\n", + "\tconnection_string='sqlite:///sqlite.db'\n", + ")\n", + "\n", + "chat_message_history.add_user_message('Hello')\n", + "chat_message_history.add_ai_message('Hi')" + ], + "metadata": { + "collapsed": false, + "ExecuteTime": { + "end_time": "2023-08-28T10:04:38.077748Z", + "start_time": "2023-08-28T10:04:36.105894Z" + } + }, + "id": "4576e914a866fb40" + }, + { + "cell_type": "code", + "execution_count": 2, + "outputs": [ + { + "data": { + "text/plain": "[HumanMessage(content='Hello', additional_kwargs={}, example=False),\n AIMessage(content='Hi', additional_kwargs={}, example=False)]" + }, + "execution_count": 2, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "chat_message_history.messages" + ], + "metadata": { + "collapsed": false, + "ExecuteTime": { + "end_time": "2023-08-28T10:04:38.929396Z", + "start_time": "2023-08-28T10:04:38.915727Z" + } + }, + "id": "b476688cbb32ba90" + }, + { + "cell_type": "markdown", + "source": [ + "### Custom Storage Format\n", + "\n", + "By default, only the session id and message dictionary are stored in the table.\n", + "\n", + "However, sometimes you might want to store some additional information, like message date, author, language etc.\n", + "\n", + "To do that, you can create a custom message converter by implementing the **BaseMessageConverter** interface." 
+ ], + "metadata": { + "collapsed": false + }, + "id": "2e5337719d5614fd" + }, + { + "cell_type": "code", + "execution_count": 3, + "outputs": [], + "source": [ + "from datetime import datetime\n", + "from langchain.schema import BaseMessage, HumanMessage, AIMessage, SystemMessage\n", + "from typing import Any\n", + "from sqlalchemy import Column, Integer, Text, DateTime\n", + "from sqlalchemy.orm import declarative_base\n", + "from langchain.memory.chat_message_histories.sql import BaseMessageConverter\n", + "\n", + "\n", + "Base = declarative_base()\n", + "\n", + "\n", + "class CustomMessage(Base):\n", + "\t__tablename__ = 'custom_message_store'\n", + "\n", + "\tid = Column(Integer, primary_key=True)\n", + "\tsession_id = Column(Text)\n", + "\ttype = Column(Text)\n", + "\tcontent = Column(Text)\n", + "\tcreated_at = Column(DateTime)\n", + "\tauthor_email = Column(Text)\n", + "\n", + "\n", + "class CustomMessageConverter(BaseMessageConverter):\n", + "\tdef __init__(self, author_email: str):\n", + "\t\tself.author_email = author_email\n", + "\t\n", + "\tdef from_sql_model(self, sql_message: Any) -> BaseMessage:\n", + "\t\tif sql_message.type == 'human':\n", + "\t\t\treturn HumanMessage(\n", + "\t\t\t\tcontent=sql_message.content,\n", + "\t\t\t)\n", + "\t\telif sql_message.type == 'ai':\n", + "\t\t\treturn AIMessage(\n", + "\t\t\t\tcontent=sql_message.content,\n", + "\t\t\t)\n", + "\t\telif sql_message.type == 'system':\n", + "\t\t\treturn SystemMessage(\n", + "\t\t\t\tcontent=sql_message.content,\n", + "\t\t\t)\n", + "\t\telse:\n", + "\t\t\traise ValueError(f'Unknown message type: {sql_message.type}')\n", + "\t\n", + "\tdef to_sql_model(self, message: BaseMessage, session_id: str) -> Any:\n", + "\t\tnow = datetime.now()\n", + "\t\treturn CustomMessage(\n", + "\t\t\tsession_id=session_id,\n", + "\t\t\ttype=message.type,\n", + "\t\t\tcontent=message.content,\n", + "\t\t\tcreated_at=now,\n", + "\t\t\tauthor_email=self.author_email\n", + "\t\t)\n", + "\t\n", + "\tdef 
get_sql_model_class(self) -> Any:\n", + "\t\treturn CustomMessage\n", + "\n", + "\n", + "chat_message_history = SQLChatMessageHistory(\n", + "\tsession_id='test_session',\n", + "\tconnection_string='sqlite:///sqlite.db',\n", + "\tcustom_message_converter=CustomMessageConverter(\n", + "\t\tauthor_email='test@example.com'\n", + " )\n", + ")\n", + "\n", + "chat_message_history.add_user_message('Hello')\n", + "chat_message_history.add_ai_message('Hi')" + ], + "metadata": { + "collapsed": false, + "ExecuteTime": { + "end_time": "2023-08-28T10:04:41.510498Z", + "start_time": "2023-08-28T10:04:41.494912Z" + } + }, + "id": "fdfde84c07d071bb" + }, + { + "cell_type": "code", + "execution_count": 4, + "outputs": [ + { + "data": { + "text/plain": "[HumanMessage(content='Hello', additional_kwargs={}, example=False),\n AIMessage(content='Hi', additional_kwargs={}, example=False)]" + }, + "execution_count": 4, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "chat_message_history.messages" + ], + "metadata": { + "collapsed": false, + "ExecuteTime": { + "end_time": "2023-08-28T10:04:43.497990Z", + "start_time": "2023-08-28T10:04:43.492517Z" + } + }, + "id": "4a6a54d8a9e2856f" + }, + { + "cell_type": "markdown", + "source": [ + "You might also want to change the name of the session_id column. In this case, you'll need to specify the `session_id_field_name` parameter." 
+ ], + "metadata": { + "collapsed": false + }, + "id": "622aded629a1adeb" + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 2 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython2", + "version": "2.7.6" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/libs/langchain/langchain/memory/chat_message_histories/sql.py b/libs/langchain/langchain/memory/chat_message_histories/sql.py index 43fedd531d2..019ffd58b80 100644 --- a/libs/langchain/langchain/memory/chat_message_histories/sql.py +++ b/libs/langchain/langchain/memory/chat_message_histories/sql.py @@ -1,6 +1,7 @@ import json import logging -from typing import List +from abc import ABC, abstractmethod +from typing import Any, List, Optional from sqlalchemy import Column, Integer, Text, create_engine @@ -18,6 +19,25 @@ from langchain.schema.messages import BaseMessage, _message_to_dict, messages_fr logger = logging.getLogger(__name__) +class BaseMessageConverter(ABC): + """The class responsible for converting BaseMessage to your SQLAlchemy model.""" + + @abstractmethod + def from_sql_model(self, sql_message: Any) -> BaseMessage: + """Convert a SQLAlchemy model to a BaseMessage instance.""" + raise NotImplementedError + + @abstractmethod + def to_sql_model(self, message: BaseMessage, session_id: str) -> Any: + """Convert a BaseMessage instance to a SQLAlchemy model.""" + raise NotImplementedError + + @abstractmethod + def get_sql_model_class(self) -> Any: + """Get the SQLAlchemy model class.""" + raise NotImplementedError + + def create_message_model(table_name, DynamicBase): # type: ignore """ Create a message model for a given table name. 
@@ -41,6 +61,24 @@ def create_message_model(table_name, DynamicBase): # type: ignore return Message +class DefaultMessageConverter(BaseMessageConverter): + """The default message converter for SQLChatMessageHistory.""" + + def __init__(self, table_name: str): + self.model_class = create_message_model(table_name, declarative_base()) + + def from_sql_model(self, sql_message: Any) -> BaseMessage: + return messages_from_dict([json.loads(sql_message.message)])[0] + + def to_sql_model(self, message: BaseMessage, session_id: str) -> Any: + return self.model_class( + session_id=session_id, message=json.dumps(_message_to_dict(message)) + ) + + def get_sql_model_class(self) -> Any: + return self.model_class + + class SQLChatMessageHistory(BaseChatMessageHistory): """Chat message history stored in an SQL database.""" @@ -49,44 +87,49 @@ class SQLChatMessageHistory(BaseChatMessageHistory): session_id: str, connection_string: str, table_name: str = "message_store", + session_id_field_name: str = "session_id", + custom_message_converter: Optional[BaseMessageConverter] = None, ): - self.table_name = table_name self.connection_string = connection_string self.engine = create_engine(connection_string, echo=False) + self.session_id_field_name = session_id_field_name + self.converter = custom_message_converter or DefaultMessageConverter(table_name) + self.sql_model_class = self.converter.get_sql_model_class() + if not hasattr(self.sql_model_class, session_id_field_name): + raise ValueError("SQL model class must have session_id column") self._create_table_if_not_exists() self.session_id = session_id self.Session = sessionmaker(self.engine) def _create_table_if_not_exists(self) -> None: - DynamicBase = declarative_base() - self.Message = create_message_model(self.table_name, DynamicBase) - # Create all does the check for us in case the table exists. 
- DynamicBase.metadata.create_all(self.engine) + self.sql_model_class.metadata.create_all(self.engine) @property def messages(self) -> List[BaseMessage]: # type: ignore """Retrieve all messages from db""" with self.Session() as session: - result = session.query(self.Message).where( - self.Message.session_id == self.session_id + result = session.query(self.sql_model_class).where( + getattr(self.sql_model_class, self.session_id_field_name) + == self.session_id ) - items = [json.loads(record.message) for record in result] - messages = messages_from_dict(items) + messages = [] + for record in result: + messages.append(self.converter.from_sql_model(record)) return messages def add_message(self, message: BaseMessage) -> None: """Append the message to the record in db""" with self.Session() as session: - jsonstr = json.dumps(_message_to_dict(message)) - session.add(self.Message(session_id=self.session_id, message=jsonstr)) + session.add(self.converter.to_sql_model(message, self.session_id)) session.commit() def clear(self) -> None: """Clear session memory from db""" with self.Session() as session: - session.query(self.Message).filter( - self.Message.session_id == self.session_id + session.query(self.sql_model_class).filter( + getattr(self.sql_model_class, self.session_id_field_name) + == self.session_id ).delete() session.commit() diff --git a/libs/langchain/tests/unit_tests/memory/chat_message_histories/test_sql.py b/libs/langchain/tests/unit_tests/memory/chat_message_histories/test_sql.py index 42cff47b044..a01e1b77dbe 100644 --- a/libs/langchain/tests/unit_tests/memory/chat_message_histories/test_sql.py +++ b/libs/langchain/tests/unit_tests/memory/chat_message_histories/test_sql.py @@ -1,21 +1,26 @@ from pathlib import Path -from typing import Tuple +from typing import Any, Generator, Tuple import pytest +from sqlalchemy import Column, Integer, Text +from sqlalchemy.orm import DeclarativeBase from langchain.memory.chat_message_histories import SQLChatMessageHistory 
+from langchain.memory.chat_message_histories.sql import DefaultMessageConverter from langchain.schema.messages import AIMessage, HumanMessage -# @pytest.fixture(params=[("SQLite"), ("postgresql")]) -@pytest.fixture(params=[("SQLite")]) -def sql_histories(request, tmp_path: Path): # type: ignore - if request.param == "SQLite": - file_path = tmp_path / "db.sqlite3" - con_str = f"sqlite:///{file_path}" - elif request.param == "postgresql": - con_str = "postgresql://postgres:postgres@localhost/postgres" +@pytest.fixture() +def con_str(tmp_path: Path) -> str: + file_path = tmp_path / "db.sqlite3" + con_str = f"sqlite:///{file_path}" + return con_str + +@pytest.fixture() +def sql_histories( + con_str: str, +) -> Generator[Tuple[SQLChatMessageHistory, SQLChatMessageHistory], None, None]: message_history = SQLChatMessageHistory( session_id="123", connection_string=con_str, table_name="test_table" ) @@ -24,7 +29,7 @@ def sql_histories(request, tmp_path: Path): # type: ignore session_id="456", connection_string=con_str, table_name="test_table" ) - yield (message_history, other_history) + yield message_history, other_history message_history.clear() other_history.clear() @@ -83,3 +88,24 @@ def test_clear_messages( sql_history.clear() assert len(sql_history.messages) == 0 assert len(other_history.messages) == 1 + + +def test_model_no_session_id_field_error(con_str: str) -> None: + class Base(DeclarativeBase): + pass + + class Model(Base): + __tablename__ = "test_table" + id = Column(Integer, primary_key=True) + test_field = Column(Text) + + class CustomMessageConverter(DefaultMessageConverter): + def get_sql_model_class(self) -> Any: + return Model + + with pytest.raises(ValueError): + SQLChatMessageHistory( + "test", + con_str, + custom_message_converter=CustomMessageConverter("test_table"), + )