From b9f5104e6cd94066c4e2e288742e4852a5408c59 Mon Sep 17 00:00:00 2001 From: Ian Date: Tue, 23 Jan 2024 05:56:56 +0800 Subject: [PATCH] community[minor]: Store Message History to TiDB Database (#16304) This pull request integrates the TiDB database into LangChain for storing message history, marking one of several steps towards a comprehensive integration of TiDB with LangChain. A simple usage ```python from datetime import datetime from langchain_community.chat_message_histories import TiDBChatMessageHistory history = TiDBChatMessageHistory( connection_string="mysql+pymysql://:@:4000/?ssl_ca=/etc/ssl/cert.pem&ssl_verify_cert=true&ssl_verify_identity=true", session_id="code_gen", earliest_time=datetime.utcnow(), # Optional to set earliest_time to load messages after this time point. ) history.add_user_message("hi! How's feature going?") history.add_ai_message("It's almost done") ``` --- .../memory/tidb_chat_message_history.ipynb | 77 +++++++++ .../chat_message_histories/__init__.py | 2 + .../chat_message_histories/tidb.py | 148 ++++++++++++++++++ .../chat_message_histories/test_tidb.py | 101 ++++++++++++ 4 files changed, 328 insertions(+) create mode 100644 docs/docs/integrations/memory/tidb_chat_message_history.ipynb create mode 100644 libs/community/langchain_community/chat_message_histories/tidb.py create mode 100644 libs/community/tests/integration_tests/chat_message_histories/test_tidb.py diff --git a/docs/docs/integrations/memory/tidb_chat_message_history.ipynb b/docs/docs/integrations/memory/tidb_chat_message_history.ipynb new file mode 100644 index 00000000000..8a49af973d8 --- /dev/null +++ b/docs/docs/integrations/memory/tidb_chat_message_history.ipynb @@ -0,0 +1,77 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# TiDB\n", + "\n", + "> [TiDB](https://github.com/pingcap/tidb) is an open-source, cloud-native, distributed, MySQL-Compatible database for elastic scale and real-time analytics.\n", + "\n", + "This notebook 
introduces how to use TiDB to store chat message history. " + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [], + "source": [ + "from datetime import datetime\n", + "\n", + "from langchain_community.chat_message_histories import TiDBChatMessageHistory\n", + "\n", + "history = TiDBChatMessageHistory(\n", + " connection_string=\"mysql+pymysql://:@:4000/?ssl_ca=/etc/ssl/cert.pem&ssl_verify_cert=true&ssl_verify_identity=true\",\n", + " session_id=\"code_gen\",\n", + " earliest_time=datetime.utcnow(), # Optional to set earliest_time to load messages after this time point.\n", + ")\n", + "\n", + "history.add_user_message(\"hi! How's feature going?\")\n", + "history.add_ai_message(\"It's almost done\")" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "[HumanMessage(content=\"hi! How's feature going?\"),\n", + " AIMessage(content=\"It's almost done\")]" + ] + }, + "execution_count": 3, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "history.messages" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "langchain", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.13" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/libs/community/langchain_community/chat_message_histories/__init__.py b/libs/community/langchain_community/chat_message_histories/__init__.py index a45ecb7ead6..8803810c1bb 100644 --- a/libs/community/langchain_community/chat_message_histories/__init__.py +++ b/libs/community/langchain_community/chat_message_histories/__init__.py @@ -35,6 +35,7 @@ from langchain_community.chat_message_histories.sql import SQLChatMessageHistory from 
"""Chat message history backed by a TiDB (MySQL-compatible) database."""

import json
import logging
from datetime import datetime
from typing import Any, Dict, List, Optional

from langchain_core.chat_history import BaseChatMessageHistory
from langchain_core.messages import BaseMessage, message_to_dict, messages_from_dict
from sqlalchemy import create_engine, text
from sqlalchemy.exc import SQLAlchemyError
from sqlalchemy.orm import sessionmaker

logger = logging.getLogger(__name__)


class TiDBChatMessageHistory(BaseChatMessageHistory):
    """Chat message history stored in a TiDB table.

    Each message is serialized to JSON and written as one row tagged with the
    session ID. Loaded and newly added messages are mirrored in ``self.cache``
    so that reading ``messages`` does not query the database every time.

    Database errors (``SQLAlchemyError``) are logged and swallowed by every
    method; the failing operation is rolled back and leaves the cache as it was.
    """

    def __init__(
        self,
        session_id: str,
        connection_string: str,
        table_name: str = "langchain_message_store",
        earliest_time: Optional[datetime] = None,
    ):
        """Connect to TiDB, make sure the table exists and load the history.

        Args:
            session_id: The ID of the chat session.
            connection_string: SQLAlchemy URL of the TiDB database, e.g.
                ``mysql+pymysql://USER:PASSWORD@HOST:4000/DATABASE`` followed by
                ``?ssl_ca=/etc/ssl/cert.pem&ssl_verify_cert=true&ssl_verify_identity=true``.
            table_name: Table used to store the chat messages. The name is
                interpolated into the SQL text (identifiers cannot be bound as
                parameters), so it must come from trusted configuration, never
                from end-user input.
            earliest_time: If given, only messages whose ``create_time`` is at
                or after this point are loaded into the cache.
        """
        self.session_id = session_id
        self.table_name = table_name
        self.earliest_time = earliest_time
        self.cache: List[BaseMessage] = []

        # One engine and one long-lived session per history object.
        self.engine = create_engine(connection_string)
        Session = sessionmaker(bind=self.engine)
        self.session = Session()

        self._create_table_if_not_exists()
        self._load_messages_to_cache()

    def _create_table_if_not_exists(self) -> None:
        """Create the message table if it does not already exist.

        A failure is logged and rolled back rather than raised.
        """
        create_table_query = text(
            f"""
            CREATE TABLE IF NOT EXISTS {self.table_name} (
                id INT AUTO_INCREMENT PRIMARY KEY,
                session_id VARCHAR(255) NOT NULL,
                message JSON NOT NULL,
                create_time TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
                INDEX session_idx (session_id)
            );"""
        )
        try:
            self.session.execute(create_table_query)
            self.session.commit()
        except SQLAlchemyError as e:
            logger.error(f"Error creating table: {e}")
            self.session.rollback()

    def _load_messages_to_cache(self) -> None:
        """Append this session's stored messages to the cache, oldest first.

        Rows are filtered by ``session_id`` and, when ``earliest_time`` is
        set, by ``create_time >= earliest_time``. A database error is logged
        and rolled back, leaving the cache untouched.
        """
        conditions = ["session_id = :session_id"]
        params: Dict[str, Any] = {"session_id": self.session_id}
        if self.earliest_time is not None:
            # Bound parameter instead of formatting the value into the SQL text.
            conditions.append("create_time >= :earliest_time")
            params["earliest_time"] = self.earliest_time
        where_clause = " AND ".join(conditions)

        query = text(
            f"""
            SELECT message FROM {self.table_name}
            WHERE {where_clause}
            ORDER BY id;
            """
        )
        try:
            rows = self.session.execute(query, params).fetchall()
            # The SELECT implicitly opened a transaction on the session; end it
            # so a later reload is not served from the same transaction.
            self.session.commit()
        except SQLAlchemyError as e:
            logger.error(f"Error loading messages to cache: {e}")
            self.session.rollback()
            return

        self.cache.extend(messages_from_dict([json.loads(row[0]) for row in rows]))

    @property
    def messages(self) -> List[BaseMessage]:
        """Return the cached messages, reloading from the database if empty."""
        if not self.cache:
            self.reload_cache()
        return self.cache

    def add_message(self, message: BaseMessage) -> None:
        """Insert a message into the table and, on success, into the cache."""
        query = text(
            f"INSERT INTO {self.table_name} (session_id, message) "
            "VALUES (:session_id, :message);"
        )
        params = {
            "session_id": self.session_id,
            "message": json.dumps(message_to_dict(message)),
        }
        try:
            self.session.execute(query, params)
            self.session.commit()
            self.cache.append(message)
        except SQLAlchemyError as e:
            logger.error(f"Error adding message: {e}")
            self.session.rollback()

    def clear(self) -> None:
        """Delete every stored message of this session and empty the cache."""
        query = text(f"DELETE FROM {self.table_name} WHERE session_id = :session_id;")
        try:
            self.session.execute(query, {"session_id": self.session_id})
            self.session.commit()
            self.cache.clear()
        except SQLAlchemyError as e:
            logger.error(f"Error clearing messages: {e}")
            self.session.rollback()

    def reload_cache(self) -> None:
        """Discard the cache and load the messages from the database again."""
        self.cache.clear()
        self._load_messages_to_cache()

    def __del__(self) -> None:
        """Close the session and release the engine's pooled connections."""
        # __init__ may have failed before these attributes were assigned
        # (e.g. an invalid connection string), so look them up defensively.
        session = getattr(self, "session", None)
        if session is not None:
            session.close()
        engine = getattr(self, "engine", None)
        if engine is not None:
            engine.dispose()
"""Integration tests for TiDBChatMessageHistory.

The tests need a reachable TiDB instance whose SQLAlchemy URL is supplied in
the ``TEST_TiDB_CHAT_URL`` environment variable; without it they are skipped.
"""

import os
import time
from datetime import datetime

import pytest
from langchain_core.messages import AIMessage, HumanMessage

from langchain_community.chat_message_histories import TiDBChatMessageHistory

CONNECTION_STRING = os.getenv("TEST_TiDB_CHAT_URL", "")
# An unset or empty URL means there is no database to test against.
tidb_available = bool(CONNECTION_STRING)


@pytest.mark.skipif(not tidb_available, reason="tidb is not available")
def test_add_messages() -> None:
    """Basic testing: adding messages to the TiDBChatMessageHistory."""
    message_store = TiDBChatMessageHistory("23334", CONNECTION_STRING)
    message_store.clear()
    assert len(message_store.messages) == 0
    message_store.add_user_message("Hello! Language Chain!")
    message_store.add_ai_message("Hi Guys!")

    # create another message store to check if the messages are stored correctly
    message_store_another = TiDBChatMessageHistory("46666", CONNECTION_STRING)
    message_store_another.clear()
    assert len(message_store_another.messages) == 0
    message_store_another.add_user_message("Hello! Bot!")
    message_store_another.add_ai_message("Hi there!")
    message_store_another.add_user_message("How's this pr going?")

    # Now check if the messages are stored in the database correctly
    assert len(message_store.messages) == 2
    assert isinstance(message_store.messages[0], HumanMessage)
    assert isinstance(message_store.messages[1], AIMessage)
    assert message_store.messages[0].content == "Hello! Language Chain!"
    assert message_store.messages[1].content == "Hi Guys!"

    assert len(message_store_another.messages) == 3
    assert isinstance(message_store_another.messages[0], HumanMessage)
    assert isinstance(message_store_another.messages[1], AIMessage)
    assert isinstance(message_store_another.messages[2], HumanMessage)
    assert message_store_another.messages[0].content == "Hello! Bot!"
    assert message_store_another.messages[1].content == "Hi there!"
    assert message_store_another.messages[2].content == "How's this pr going?"

    # Now clear the first history; the second session must be unaffected
    message_store.clear()
    assert len(message_store.messages) == 0
    assert len(message_store_another.messages) == 3
    message_store_another.clear()
    assert len(message_store.messages) == 0
    assert len(message_store_another.messages) == 0


@pytest.mark.skipif(not tidb_available, reason="tidb is not available")
def test_tidb_recent_chat_message() -> None:
    """Test the TiDBChatMessageHistory with earliest_time parameter."""
    # prepare some messages
    message_store = TiDBChatMessageHistory("2333", CONNECTION_STRING)
    message_store.clear()
    assert len(message_store.messages) == 0
    message_store.add_user_message("Hello! Language Chain!")
    message_store.add_ai_message("Hi Guys!")
    assert len(message_store.messages) == 2
    assert isinstance(message_store.messages[0], HumanMessage)
    assert isinstance(message_store.messages[1], AIMessage)
    assert message_store.messages[0].content == "Hello! Language Chain!"
    assert message_store.messages[1].content == "Hi Guys!"

    # now we add some recent messages to the database
    # NOTE(review): comparing this naive UTC value with the server-side
    # create_time presumably requires the database session to run in UTC.
    earliest_time = datetime.utcnow()
    # Wait so the rows inserted below get a create_time after earliest_time.
    time.sleep(1)

    message_store.add_user_message("How's this pr going?")
    message_store.add_ai_message("It's almost done!")
    assert len(message_store.messages) == 4
    assert isinstance(message_store.messages[2], HumanMessage)
    assert isinstance(message_store.messages[3], AIMessage)
    assert message_store.messages[2].content == "How's this pr going?"
    assert message_store.messages[3].content == "It's almost done!"

    # now we create another message store with earliest_time parameter
    message_store_another = TiDBChatMessageHistory(
        "2333", CONNECTION_STRING, earliest_time=earliest_time
    )
    assert len(message_store_another.messages) == 2
    assert isinstance(message_store_another.messages[0], HumanMessage)
    assert isinstance(message_store_another.messages[1], AIMessage)
    assert message_store_another.messages[0].content == "How's this pr going?"
    assert message_store_another.messages[1].content == "It's almost done!"

    # now we clear the message store
    message_store.clear()
    assert len(message_store.messages) == 0