mirror of
https://github.com/hwchase17/langchain.git
synced 2025-07-17 18:23:59 +00:00
Add StreamlitChatMessageHistory (#8497)
Add a StreamlitChatMessageHistory class that stores chat messages in [Streamlit's Session State](https://docs.streamlit.io/library/api-reference/session-state). Note: The integration test uses a currently-experimental Streamlit testing framework to simulate the execution of a Streamlit app. Marking this PR as draft until I confirm with the Streamlit team that we're comfortable supporting it. --------- Co-authored-by: Bagatur <baskaryan@gmail.com>
This commit is contained in:
parent
8961c720b8
commit
6705928b9d
@ -0,0 +1,61 @@
|
|||||||
|
{
|
||||||
|
"cells": [
|
||||||
|
{
|
||||||
|
"cell_type": "markdown",
|
||||||
|
"id": "91c6a7ef",
|
||||||
|
"metadata": {},
|
||||||
|
"source": [
|
||||||
|
"# Streamlit Chat Message History\n",
|
||||||
|
"\n",
|
||||||
|
"This notebook goes over how to use Streamlit to store chat message history. Note, StreamlitChatMessageHistory only works when run in a Streamlit app. For more on Streamlit check out their\n",
|
||||||
|
"[getting started documentation](https://docs.streamlit.io/library/get-started)."
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": null,
|
||||||
|
"id": "d15e3302",
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"from langchain.memory import StreamlitChatMessageHistory\n",
|
||||||
|
"\n",
|
||||||
|
"history = StreamlitChatMessageHistory(\"foo\")\n",
|
||||||
|
"\n",
|
||||||
|
"history.add_user_message(\"hi!\")\n",
|
||||||
|
"history.add_ai_message(\"whats up?\")"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": null,
|
||||||
|
"id": "64fc465e",
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"history.messages"
|
||||||
|
]
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"metadata": {
|
||||||
|
"kernelspec": {
|
||||||
|
"display_name": "poetry-venv",
|
||||||
|
"language": "python",
|
||||||
|
"name": "poetry-venv"
|
||||||
|
},
|
||||||
|
"language_info": {
|
||||||
|
"codemirror_mode": {
|
||||||
|
"name": "ipython",
|
||||||
|
"version": 3
|
||||||
|
},
|
||||||
|
"file_extension": ".py",
|
||||||
|
"mimetype": "text/x-python",
|
||||||
|
"name": "python",
|
||||||
|
"nbconvert_exporter": "python",
|
||||||
|
"pygments_lexer": "ipython3",
|
||||||
|
"version": "3.9.1"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"nbformat": 4,
|
||||||
|
"nbformat_minor": 5
|
||||||
|
}
|
@ -42,6 +42,7 @@ from langchain.memory.chat_message_histories import (
|
|||||||
PostgresChatMessageHistory,
|
PostgresChatMessageHistory,
|
||||||
RedisChatMessageHistory,
|
RedisChatMessageHistory,
|
||||||
SQLChatMessageHistory,
|
SQLChatMessageHistory,
|
||||||
|
StreamlitChatMessageHistory,
|
||||||
ZepChatMessageHistory,
|
ZepChatMessageHistory,
|
||||||
)
|
)
|
||||||
from langchain.memory.combined import CombinedMemory
|
from langchain.memory.combined import CombinedMemory
|
||||||
@ -87,6 +88,7 @@ __all__ = [
|
|||||||
"SQLChatMessageHistory",
|
"SQLChatMessageHistory",
|
||||||
"SQLiteEntityStore",
|
"SQLiteEntityStore",
|
||||||
"SimpleMemory",
|
"SimpleMemory",
|
||||||
|
"StreamlitChatMessageHistory",
|
||||||
"VectorStoreRetrieverMemory",
|
"VectorStoreRetrieverMemory",
|
||||||
"ZepChatMessageHistory",
|
"ZepChatMessageHistory",
|
||||||
"ZepMemory",
|
"ZepMemory",
|
||||||
|
@ -13,6 +13,9 @@ from langchain.memory.chat_message_histories.mongodb import MongoDBChatMessageHi
|
|||||||
from langchain.memory.chat_message_histories.postgres import PostgresChatMessageHistory
|
from langchain.memory.chat_message_histories.postgres import PostgresChatMessageHistory
|
||||||
from langchain.memory.chat_message_histories.redis import RedisChatMessageHistory
|
from langchain.memory.chat_message_histories.redis import RedisChatMessageHistory
|
||||||
from langchain.memory.chat_message_histories.sql import SQLChatMessageHistory
|
from langchain.memory.chat_message_histories.sql import SQLChatMessageHistory
|
||||||
|
from langchain.memory.chat_message_histories.streamlit import (
|
||||||
|
StreamlitChatMessageHistory,
|
||||||
|
)
|
||||||
from langchain.memory.chat_message_histories.zep import ZepChatMessageHistory
|
from langchain.memory.chat_message_histories.zep import ZepChatMessageHistory
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
@ -27,5 +30,6 @@ __all__ = [
|
|||||||
"PostgresChatMessageHistory",
|
"PostgresChatMessageHistory",
|
||||||
"RedisChatMessageHistory",
|
"RedisChatMessageHistory",
|
||||||
"SQLChatMessageHistory",
|
"SQLChatMessageHistory",
|
||||||
|
"StreamlitChatMessageHistory",
|
||||||
"ZepChatMessageHistory",
|
"ZepChatMessageHistory",
|
||||||
]
|
]
|
||||||
|
@ -0,0 +1,40 @@
|
|||||||
|
from typing import List
|
||||||
|
|
||||||
|
from langchain.schema import (
|
||||||
|
BaseChatMessageHistory,
|
||||||
|
)
|
||||||
|
from langchain.schema.messages import BaseMessage
|
||||||
|
|
||||||
|
|
||||||
|
class StreamlitChatMessageHistory(BaseChatMessageHistory):
|
||||||
|
"""
|
||||||
|
Chat message history that stores messages in Streamlit session state.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
key: The key to use in Streamlit session state for storing messages.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, key: str = "langchain_messages"):
|
||||||
|
try:
|
||||||
|
import streamlit as st
|
||||||
|
except ImportError as e:
|
||||||
|
raise ImportError(
|
||||||
|
"Unable to import streamlit, please run `pip install streamlit`."
|
||||||
|
) from e
|
||||||
|
|
||||||
|
if key not in st.session_state:
|
||||||
|
st.session_state[key] = []
|
||||||
|
self._messages = st.session_state[key]
|
||||||
|
|
||||||
|
@property
|
||||||
|
def messages(self) -> List[BaseMessage]: # type: ignore
|
||||||
|
"""Retrieve the current list of messages"""
|
||||||
|
return self._messages
|
||||||
|
|
||||||
|
def add_message(self, message: BaseMessage) -> None:
|
||||||
|
"""Add a message to the session memory"""
|
||||||
|
self._messages.append(message)
|
||||||
|
|
||||||
|
def clear(self) -> None:
|
||||||
|
"""Clear session memory"""
|
||||||
|
self._messages.clear()
|
@ -0,0 +1,64 @@
|
|||||||
|
"""Unit tests for StreamlitChatMessageHistory functionality."""
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
test_script = """
|
||||||
|
import json
|
||||||
|
import streamlit as st
|
||||||
|
from langchain.memory import ConversationBufferMemory
|
||||||
|
from langchain.memory.chat_message_histories import StreamlitChatMessageHistory
|
||||||
|
from langchain.schema.messages import _message_to_dict
|
||||||
|
|
||||||
|
message_history = StreamlitChatMessageHistory()
|
||||||
|
memory = ConversationBufferMemory(chat_memory=message_history, return_messages=True)
|
||||||
|
|
||||||
|
# Add some messages
|
||||||
|
if st.checkbox("add initial messages", value=True):
|
||||||
|
memory.chat_memory.add_ai_message("This is me, the AI")
|
||||||
|
memory.chat_memory.add_user_message("This is me, the human")
|
||||||
|
else:
|
||||||
|
st.markdown("Skipped add")
|
||||||
|
|
||||||
|
# Clear messages if checked
|
||||||
|
if st.checkbox("clear messages"):
|
||||||
|
st.markdown("Cleared!")
|
||||||
|
memory.chat_memory.clear()
|
||||||
|
|
||||||
|
# Write the output to st.code as a json blob for inspection
|
||||||
|
messages = memory.chat_memory.messages
|
||||||
|
messages_json = json.dumps([_message_to_dict(msg) for msg in messages])
|
||||||
|
st.text(messages_json)
|
||||||
|
"""
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.requires("streamlit")
|
||||||
|
def test_memory_with_message_store() -> None:
|
||||||
|
try:
|
||||||
|
from streamlit.testing.script_interactions import InteractiveScriptTests
|
||||||
|
except ModuleNotFoundError:
|
||||||
|
pytest.skip("Incorrect version of Streamlit installed")
|
||||||
|
|
||||||
|
test_handler = InteractiveScriptTests()
|
||||||
|
test_handler.setUp()
|
||||||
|
try:
|
||||||
|
sr = test_handler.script_from_string(test_script).run()
|
||||||
|
except TypeError:
|
||||||
|
# Earlier version expected 2 arguments
|
||||||
|
sr = test_handler.script_from_string("memory_test.py", test_script).run()
|
||||||
|
|
||||||
|
# Initial run should write two messages
|
||||||
|
messages_json = sr.get("text")[-1].value
|
||||||
|
assert "This is me, the AI" in messages_json
|
||||||
|
assert "This is me, the human" in messages_json
|
||||||
|
|
||||||
|
# Uncheck the initial write, they should persist in session_state
|
||||||
|
sr = sr.get("checkbox")[0].uncheck().run()
|
||||||
|
assert sr.get("markdown")[0].value == "Skipped add"
|
||||||
|
messages_json = sr.get("text")[-1].value
|
||||||
|
assert "This is me, the AI" in messages_json
|
||||||
|
assert "This is me, the human" in messages_json
|
||||||
|
|
||||||
|
# Clear the message history
|
||||||
|
sr = sr.get("checkbox")[1].check().run()
|
||||||
|
assert sr.get("markdown")[1].value == "Cleared!"
|
||||||
|
messages_json = sr.get("text")[-1].value
|
||||||
|
assert messages_json == "[]"
|
Loading…
Reference in New Issue
Block a user