diff --git a/docs/extras/integrations/memory/xata_chat_message_history.ipynb b/docs/extras/integrations/memory/xata_chat_message_history.ipynb new file mode 100644 index 00000000000..938f6c44b90 --- /dev/null +++ b/docs/extras/integrations/memory/xata_chat_message_history.ipynb @@ -0,0 +1,326 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Xata chat memory\n", + "\n", + "[Xata](https://xata.io) is a serverless data platform, based on PostgreSQL and Elasticsearch. It provides a Python SDK for interacting with your database, and a UI for managing your data. With the `XataChatMessageHistory` class, you can use Xata databases for longer-term persistence of chat sessions.\n", + "\n", + "This notebook covers:\n", + "\n", + "* A simple example showing what `XataChatMessageHistory` does.\n", + "* A more complex example using a REACT agent that answer questions based on a knowledge based or documentation (stored in Xata as a vector store) and also having a long-term searchable history of its past messages (stored in Xata as a memory store)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Setup\n", + "\n", + "### Create a database\n", + "\n", + "In the [Xata UI](https://app.xata.io) create a new database. You can name it whatever you want, in this notepad we'll use `langchain`. The Langchain integration can auto-create the table used for storying the memory, and this is what we'll use in this example. If you want to pre-create the table, ensure it has the right schema and set `create_table` to `False` when creating the class. Pre-creating the table saves one round-trip to the database during each session initialization." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Let's first install our dependencies:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "!pip install xata==1.0.0rc0 openai langchain" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Next, we need to get the environment variables for Xata. You can create a new API key by visiting your [account settings](https://app.xata.io/settings). To find the database URL, go to the Settings page of the database that you have created. The database URL should look something like this: `https://demo-uni3q8.eu-west-1.xata.sh/db/langchain`." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import getpass\n", + "\n", + "api_key = getpass.getpass(\"Xata API key: \")\n", + "db_url = input(\"Xata database URL (copy it from your DB settings):\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Create a simple memory store\n", + "\n", + "To test the memory store functionality in isolation, let's use the following code snippet:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "from langchain.memory import XataChatMessageHistory\n", + "\n", + "history = XataChatMessageHistory(\n", + " session_id=\"session-1\",\n", + " api_key=api_key,\n", + " db_url=db_url,\n", + " table_name=\"memory\"\n", + ")\n", + "\n", + "history.add_user_message(\"hi!\")\n", + "\n", + "history.add_ai_message(\"whats up?\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "The above code creates a session with the ID `session-1` and stores two messages in it. After running the above, if you visit the Xata UI, you should see a table named `memory` and the two messages added to it.\n", + "\n", + "You can retrieve the message history for a particular session with the following code:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "history.messages" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Conversational Q&A chain on your data with memory\n", + "\n", + "Let's now see a more complex example in which we combine OpenAI, the Xata Vector Store integration, and the Xata memory store integration to create a Q&A chat bot on your data, with follow-up questions and history." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "We're going to need to access the OpenAI API, so let's configure the API key:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import os\n", + "\n", + "os.environ[\"OPENAI_API_KEY\"] = getpass.getpass(\"OpenAI API Key:\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "To store the documents that the chatbot will search for answers, add a table named `docs` to your `langchain` database using the Xata UI, and add the following columns:\n", + "\n", + "* `content` of type \"Text\". This is used to store the `Document.pageContent` values.\n", + "* `embedding` of type \"Vector\". Use the dimension used by the model you plan to use. In this notebook we use OpenAI embeddings, which have 1536 dimensions." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Let's create the vector store and add some sample docs to it:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "from langchain.embeddings.openai import OpenAIEmbeddings\n", + "from langchain.vectorstores.xata import XataVectorStore\n", + "\n", + "embeddings = OpenAIEmbeddings()\n", + "\n", + "texts = [\n", + " \"Xata is a Serverless Data platform based on PostgreSQL\",\n", + " \"Xata offers a built-in vector type that can be used to store and query vectors\",\n", + " \"Xata includes similarity search\"\n", + "]\n", + "\n", + "vector_store = XataVectorStore.from_texts(texts, embeddings, api_key=api_key, db_url=db_url, table_name=\"docs\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "After running the above command, if you go to the Xata UI, you should see the documents loaded together with their embeddings in the `docs` table." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Let's now create a ConversationBufferMemory to store the chat messages from both the user and the AI." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "from langchain.memory import ConversationBufferMemory\n", + "from uuid import uuid4\n", + "\n", + "chat_memory = XataChatMessageHistory(\n", + " session_id=str(uuid4()), # needs to be unique per user session\n", + " api_key=api_key,\n", + " db_url=db_url,\n", + " table_name=\"memory\"\n", + ")\n", + "memory = ConversationBufferMemory(memory_key=\"chat_history\", chat_memory=chat_memory, return_messages=True)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Now it's time to create an Agent to use both the vector store and the chat memory together." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "from langchain.agents import initialize_agent, AgentType\n", + "from langchain.agents.agent_toolkits import create_retriever_tool\n", + "from langchain.chat_models import ChatOpenAI\n", + "\n", + "tool = create_retriever_tool(\n", + " vector_store.as_retriever(), \n", + " \"search_docs\",\n", + " \"Searches and returns documents from the Xata manual. Useful when you need to answer questions about Xata.\"\n", + ")\n", + "tools = [tool]\n", + "\n", + "llm = ChatOpenAI(temperature=0)\n", + "\n", + "agent = initialize_agent(\n", + " tools,\n", + " llm,\n", + " agent=AgentType.CHAT_CONVERSATIONAL_REACT_DESCRIPTION,\n", + " verbose=True,\n", + " memory=memory)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "To test, let's tell the agent our name:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "agent.run(input=\"My name is bob\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Now, let's now ask the agent some questions about Xata:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "agent.run(input=\"What is xata?\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Notice that it answers based on the data stored in the document store. And now, let's ask a follow up question:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "agent.run(input=\"Does it support similarity search?\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "And now let's test its memory:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "agent.run(input=\"Did I tell you my name? What is it?\")" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.9" + } + }, + "nbformat": 4, + "nbformat_minor": 4 +} diff --git a/libs/langchain/langchain/memory/__init__.py b/libs/langchain/langchain/memory/__init__.py index e0ac6371fec..f2a61ef2f0d 100644 --- a/libs/langchain/langchain/memory/__init__.py +++ b/libs/langchain/langchain/memory/__init__.py @@ -43,6 +43,7 @@ from langchain.memory.chat_message_histories import ( RedisChatMessageHistory, SQLChatMessageHistory, StreamlitChatMessageHistory, + XataChatMessageHistory, ZepChatMessageHistory, ) from langchain.memory.combined import CombinedMemory @@ -90,6 +91,7 @@ __all__ = [ "SimpleMemory", "StreamlitChatMessageHistory", "VectorStoreRetrieverMemory", + "XataChatMessageHistory", "ZepChatMessageHistory", "ZepMemory", ] diff --git a/libs/langchain/langchain/memory/chat_message_histories/__init__.py b/libs/langchain/langchain/memory/chat_message_histories/__init__.py index 02241675b15..ddd23de4f80 100644 --- a/libs/langchain/langchain/memory/chat_message_histories/__init__.py +++ b/libs/langchain/langchain/memory/chat_message_histories/__init__.py @@ -17,6 +17,7 @@ from langchain.memory.chat_message_histories.sql import SQLChatMessageHistory from langchain.memory.chat_message_histories.streamlit import ( StreamlitChatMessageHistory, ) +from langchain.memory.chat_message_histories.xata import XataChatMessageHistory from langchain.memory.chat_message_histories.zep import ZepChatMessageHistory __all__ = [ @@ -33,5 +34,6 @@ __all__ = [ "RocksetChatMessageHistory", "SQLChatMessageHistory", "StreamlitChatMessageHistory", + "XataChatMessageHistory", "ZepChatMessageHistory", ] diff --git a/libs/langchain/langchain/memory/chat_message_histories/xata.py b/libs/langchain/langchain/memory/chat_message_histories/xata.py new file mode 100644 index 00000000000..de358888a0d --- /dev/null +++ b/libs/langchain/langchain/memory/chat_message_histories/xata.py @@ -0,0 +1,132 @@ +import json +from typing import List + +from langchain.schema import ( + BaseChatMessageHistory, +) +from langchain.schema.messages import BaseMessage, _message_to_dict, messages_from_dict + + +class XataChatMessageHistory(BaseChatMessageHistory): + """Chat message history stored in a Xata database.""" + + def __init__( + self, + session_id: str, + db_url: str, + api_key: str, + branch_name: str = "main", + table_name: str = "messages", + create_table: bool = True, + ) -> None: + """Initialize with Xata client.""" + try: + from xata.client import XataClient # noqa: F401 + except ImportError: + raise ValueError( + "Could not import xata python package. " + "Please install it with `pip install xata`." + ) + self._client = XataClient( + api_key=api_key, db_url=db_url, branch_name=branch_name + ) + self._table_name = table_name + self._session_id = session_id + + if create_table: + self._create_table_if_not_exists() + + def _create_table_if_not_exists(self) -> None: + r = self._client.table().get_schema(self._table_name) + if r.status_code <= 299: + return + if r.status_code != 404: + raise Exception( + f"Error checking if table exists in Xata: {r.status_code} {r}" + ) + r = self._client.table().create(self._table_name) + if r.status_code > 299: + raise Exception(f"Error creating table in Xata: {r.status_code} {r}") + r = self._client.table().set_schema( + self._table_name, + payload={ + "columns": [ + {"name": "sessionId", "type": "string"}, + {"name": "type", "type": "string"}, + {"name": "role", "type": "string"}, + {"name": "content", "type": "text"}, + {"name": "name", "type": "string"}, + {"name": "additionalKwargs", "type": "text"}, + ] + }, + ) + if r.status_code > 299: + raise Exception(f"Error setting table schema in Xata: {r.status_code} {r}") + + def add_message(self, message: BaseMessage) -> None: + """Append the message to the Xata table""" + msg = _message_to_dict(message) + r = self._client.records().insert( + self._table_name, + { + "sessionId": self._session_id, + "type": msg["type"], + "content": message.content, + "additionalKwargs": json.dumps(message.additional_kwargs), + "role": msg["data"].get("role"), + "name": msg["data"].get("name"), + }, + ) + if r.status_code > 299: + raise Exception(f"Error adding message to Xata: {r.status_code} {r}") + + @property + def messages(self) -> List[BaseMessage]: # type: ignore + r = self._client.data().query( + self._table_name, + payload={ + "filter": { + "sessionId": self._session_id, + }, + "sort": {"xata.createdAt": "asc"}, + }, + ) + if r.status_code != 200: + raise Exception(f"Error running query: {r.status_code} {r}") + msgs = messages_from_dict( + [ + { + "type": m["type"], + "data": { + "content": m["content"], + "role": m.get("role"), + "name": m.get("name"), + "additionalKwargs": json.loads(m["additionalKwargs"]), + }, + } + for m in r["records"] + ] + ) + return msgs + + def clear(self) -> None: + """Delete session from Xata table.""" + while True: + r = self._client.data().query( + self._table_name, + payload={ + "columns": ["id"], + "filter": { + "sessionId": self._session_id, + }, + }, + ) + if r.status_code != 200: + raise Exception(f"Error running query: {r.status_code} {r}") + ids = [rec["id"] for rec in r["records"]] + if len(ids) == 0: + break + operations = [ + {"delete": {"table": self._table_name, "id": id}} for id in ids + ] + self._client.records().transaction(payload={"operations": operations}) diff --git a/libs/langchain/tests/integration_tests/memory/test_xata.py b/libs/langchain/tests/integration_tests/memory/test_xata.py new file mode 100644 index 00000000000..88bd158a257 --- /dev/null +++ b/libs/langchain/tests/integration_tests/memory/test_xata.py @@ -0,0 +1,41 @@ +"""Test Xata chat memory store functionality. + +Before running this test, please create a Xata database. +""" + +import json +import os + +from langchain.memory import ConversationBufferMemory +from langchain.memory.chat_message_histories import XataChatMessageHistory +from langchain.schema.messages import _message_to_dict + + +class TestXata: + @classmethod + def setup_class(cls) -> None: + assert os.getenv("XATA_API_KEY"), "XATA_API_KEY environment variable is not set" + assert os.getenv("XATA_DB_URL"), "XATA_DB_URL environment variable is not set" + + def test_xata_chat_memory(self) -> None: + message_history = XataChatMessageHistory( + api_key=os.getenv("XATA_API_KEY", ""), + db_url=os.getenv("XATA_DB_URL", ""), + session_id="integration-test-session", + ) + memory = ConversationBufferMemory( + memory_key="baz", chat_memory=message_history, return_messages=True + ) + # add some messages + memory.chat_memory.add_ai_message("This is me, the AI") + memory.chat_memory.add_user_message("This is me, the human") + + # get the message history from the memory store and turn it into a json + messages = memory.chat_memory.messages + messages_json = json.dumps([_message_to_dict(msg) for msg in messages]) + + assert "This is me, the AI" in messages_json + assert "This is me, the human" in messages_json + + # remove the record from Redis, so the next test run won't pick it up + memory.chat_memory.clear()