mirror of
https://github.com/hwchase17/langchain.git
synced 2025-09-14 05:56:40 +00:00
Extend SQLChatMessageHistory (#9849)
### Description There is a really nice class for saving chat messages into a database - SQLChatMessageHistory. It leverages SqlAlchemy to be compatible with any supported database (in contrast with PostgresChatMessageHistory, which is basically the same but is limited to Postgres). However, the class is not really customizable in terms of what you can store. I can imagine a lot of use cases, when one will need to save a message date, along with some additional metadata. To solve this, I propose to extract the converting logic from BaseMessage to SQLAlchemy model (and vice versa) into a separate class - message converter. So instead of rewriting the whole SQLChatMessageHistory class, a user will only need to write a custom model and a simple mapping class, and pass its instance as a parameter. I also noticed that there is no documentation on this class, so I added that too, with an example of custom message converter. ### Issue N/A ### Dependencies N/A ### Tag maintainer Not yet ### Twitter handle N/A
This commit is contained in:
committed by
GitHub
parent
fed137a8a9
commit
507e46844e
@@ -1,6 +1,7 @@
|
||||
import json
|
||||
import logging
|
||||
from typing import List
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Any, List, Optional
|
||||
|
||||
from sqlalchemy import Column, Integer, Text, create_engine
|
||||
|
||||
@@ -18,6 +19,25 @@ from langchain.schema.messages import BaseMessage, _message_to_dict, messages_fr
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class BaseMessageConverter(ABC):
|
||||
"""The class responsible for converting BaseMessage to your SQLAlchemy model."""
|
||||
|
||||
@abstractmethod
|
||||
def from_sql_model(self, sql_message: Any) -> BaseMessage:
|
||||
"""Convert a SQLAlchemy model to a BaseMessage instance."""
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def to_sql_model(self, message: BaseMessage, session_id: str) -> Any:
|
||||
"""Convert a BaseMessage instance to a SQLAlchemy model."""
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def get_sql_model_class(self) -> Any:
|
||||
"""Get the SQLAlchemy model class."""
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
def create_message_model(table_name, DynamicBase): # type: ignore
|
||||
"""
|
||||
Create a message model for a given table name.
|
||||
@@ -41,6 +61,24 @@ def create_message_model(table_name, DynamicBase): # type: ignore
|
||||
return Message
|
||||
|
||||
|
||||
class DefaultMessageConverter(BaseMessageConverter):
|
||||
"""The default message converter for SQLChatMessageHistory."""
|
||||
|
||||
def __init__(self, table_name: str):
|
||||
self.model_class = create_message_model(table_name, declarative_base())
|
||||
|
||||
def from_sql_model(self, sql_message: Any) -> BaseMessage:
|
||||
return messages_from_dict([json.loads(sql_message.message)])[0]
|
||||
|
||||
def to_sql_model(self, message: BaseMessage, session_id: str) -> Any:
|
||||
return self.model_class(
|
||||
session_id=session_id, message=json.dumps(_message_to_dict(message))
|
||||
)
|
||||
|
||||
def get_sql_model_class(self) -> Any:
|
||||
return self.model_class
|
||||
|
||||
|
||||
class SQLChatMessageHistory(BaseChatMessageHistory):
|
||||
"""Chat message history stored in an SQL database."""
|
||||
|
||||
@@ -49,44 +87,49 @@ class SQLChatMessageHistory(BaseChatMessageHistory):
|
||||
session_id: str,
|
||||
connection_string: str,
|
||||
table_name: str = "message_store",
|
||||
session_id_field_name: str = "session_id",
|
||||
custom_message_converter: Optional[BaseMessageConverter] = None,
|
||||
):
|
||||
self.table_name = table_name
|
||||
self.connection_string = connection_string
|
||||
self.engine = create_engine(connection_string, echo=False)
|
||||
self.session_id_field_name = session_id_field_name
|
||||
self.converter = custom_message_converter or DefaultMessageConverter(table_name)
|
||||
self.sql_model_class = self.converter.get_sql_model_class()
|
||||
if not hasattr(self.sql_model_class, session_id_field_name):
|
||||
raise ValueError("SQL model class must have session_id column")
|
||||
self._create_table_if_not_exists()
|
||||
|
||||
self.session_id = session_id
|
||||
self.Session = sessionmaker(self.engine)
|
||||
|
||||
def _create_table_if_not_exists(self) -> None:
|
||||
DynamicBase = declarative_base()
|
||||
self.Message = create_message_model(self.table_name, DynamicBase)
|
||||
# Create all does the check for us in case the table exists.
|
||||
DynamicBase.metadata.create_all(self.engine)
|
||||
self.sql_model_class.metadata.create_all(self.engine)
|
||||
|
||||
@property
|
||||
def messages(self) -> List[BaseMessage]: # type: ignore
|
||||
"""Retrieve all messages from db"""
|
||||
with self.Session() as session:
|
||||
result = session.query(self.Message).where(
|
||||
self.Message.session_id == self.session_id
|
||||
result = session.query(self.sql_model_class).where(
|
||||
getattr(self.sql_model_class, self.session_id_field_name)
|
||||
== self.session_id
|
||||
)
|
||||
items = [json.loads(record.message) for record in result]
|
||||
messages = messages_from_dict(items)
|
||||
messages = []
|
||||
for record in result:
|
||||
messages.append(self.converter.from_sql_model(record))
|
||||
return messages
|
||||
|
||||
def add_message(self, message: BaseMessage) -> None:
|
||||
"""Append the message to the record in db"""
|
||||
with self.Session() as session:
|
||||
jsonstr = json.dumps(_message_to_dict(message))
|
||||
session.add(self.Message(session_id=self.session_id, message=jsonstr))
|
||||
session.add(self.converter.to_sql_model(message, self.session_id))
|
||||
session.commit()
|
||||
|
||||
def clear(self) -> None:
|
||||
"""Clear session memory from db"""
|
||||
|
||||
with self.Session() as session:
|
||||
session.query(self.Message).filter(
|
||||
self.Message.session_id == self.session_id
|
||||
session.query(self.sql_model_class).filter(
|
||||
getattr(self.sql_model_class, self.session_id_field_name)
|
||||
== self.session_id
|
||||
).delete()
|
||||
session.commit()
|
||||
|
@@ -1,21 +1,26 @@
|
||||
from pathlib import Path
|
||||
from typing import Tuple
|
||||
from typing import Any, Generator, Tuple
|
||||
|
||||
import pytest
|
||||
from sqlalchemy import Column, Integer, Text
|
||||
from sqlalchemy.orm import DeclarativeBase
|
||||
|
||||
from langchain.memory.chat_message_histories import SQLChatMessageHistory
|
||||
from langchain.memory.chat_message_histories.sql import DefaultMessageConverter
|
||||
from langchain.schema.messages import AIMessage, HumanMessage
|
||||
|
||||
|
||||
# @pytest.fixture(params=[("SQLite"), ("postgresql")])
|
||||
@pytest.fixture(params=[("SQLite")])
|
||||
def sql_histories(request, tmp_path: Path): # type: ignore
|
||||
if request.param == "SQLite":
|
||||
file_path = tmp_path / "db.sqlite3"
|
||||
con_str = f"sqlite:///{file_path}"
|
||||
elif request.param == "postgresql":
|
||||
con_str = "postgresql://postgres:postgres@localhost/postgres"
|
||||
@pytest.fixture()
|
||||
def con_str(tmp_path: Path) -> str:
|
||||
file_path = tmp_path / "db.sqlite3"
|
||||
con_str = f"sqlite:///{file_path}"
|
||||
return con_str
|
||||
|
||||
|
||||
@pytest.fixture()
|
||||
def sql_histories(
|
||||
con_str: str,
|
||||
) -> Generator[Tuple[SQLChatMessageHistory, SQLChatMessageHistory], None, None]:
|
||||
message_history = SQLChatMessageHistory(
|
||||
session_id="123", connection_string=con_str, table_name="test_table"
|
||||
)
|
||||
@@ -24,7 +29,7 @@ def sql_histories(request, tmp_path: Path): # type: ignore
|
||||
session_id="456", connection_string=con_str, table_name="test_table"
|
||||
)
|
||||
|
||||
yield (message_history, other_history)
|
||||
yield message_history, other_history
|
||||
message_history.clear()
|
||||
other_history.clear()
|
||||
|
||||
@@ -83,3 +88,24 @@ def test_clear_messages(
|
||||
sql_history.clear()
|
||||
assert len(sql_history.messages) == 0
|
||||
assert len(other_history.messages) == 1
|
||||
|
||||
|
||||
def test_model_no_session_id_field_error(con_str: str) -> None:
|
||||
class Base(DeclarativeBase):
|
||||
pass
|
||||
|
||||
class Model(Base):
|
||||
__tablename__ = "test_table"
|
||||
id = Column(Integer, primary_key=True)
|
||||
test_field = Column(Text)
|
||||
|
||||
class CustomMessageConverter(DefaultMessageConverter):
|
||||
def get_sql_model_class(self) -> Any:
|
||||
return Model
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
SQLChatMessageHistory(
|
||||
"test",
|
||||
con_str,
|
||||
custom_message_converter=CustomMessageConverter("test_table"),
|
||||
)
|
||||
|
Reference in New Issue
Block a user