mirror of
https://github.com/hwchase17/langchain.git
synced 2025-09-13 13:36:15 +00:00
Add ElasticsearchChatMessageHistory (#10932)
**Description**

This PR adds the `ElasticsearchChatMessageHistory` implementation that stores chat message history in the configured [Elasticsearch](https://www.elastic.co/elasticsearch/) deployment.

```python
from langchain.memory.chat_message_histories import ElasticsearchChatMessageHistory

history = ElasticsearchChatMessageHistory(
    es_url="https://my-elasticsearch-deployment-url:9200",
    index="chat-history-index",
    session_id="123"
)

history.add_ai_message("This is me, the AI")
history.add_user_message("This is me, the human")
```

**Dependencies**

- [elasticsearch client](https://elasticsearch-py.readthedocs.io/) required

Co-authored-by: Bagatur <baskaryan@gmail.com>
This commit is contained in:
@@ -36,6 +36,7 @@ from langchain.memory.chat_message_histories import (
|
||||
ChatMessageHistory,
|
||||
CosmosDBChatMessageHistory,
|
||||
DynamoDBChatMessageHistory,
|
||||
ElasticsearchChatMessageHistory,
|
||||
FileChatMessageHistory,
|
||||
MomentoChatMessageHistory,
|
||||
MongoDBChatMessageHistory,
|
||||
@@ -77,6 +78,7 @@ __all__ = [
|
||||
"ConversationTokenBufferMemory",
|
||||
"CosmosDBChatMessageHistory",
|
||||
"DynamoDBChatMessageHistory",
|
||||
"ElasticsearchChatMessageHistory",
|
||||
"FileChatMessageHistory",
|
||||
"InMemoryEntityStore",
|
||||
"MomentoChatMessageHistory",
|
||||
|
@@ -3,6 +3,9 @@ from langchain.memory.chat_message_histories.cassandra import (
|
||||
)
|
||||
from langchain.memory.chat_message_histories.cosmos_db import CosmosDBChatMessageHistory
|
||||
from langchain.memory.chat_message_histories.dynamodb import DynamoDBChatMessageHistory
|
||||
from langchain.memory.chat_message_histories.elasticsearch import (
|
||||
ElasticsearchChatMessageHistory,
|
||||
)
|
||||
from langchain.memory.chat_message_histories.file import FileChatMessageHistory
|
||||
from langchain.memory.chat_message_histories.firestore import (
|
||||
FirestoreChatMessageHistory,
|
||||
@@ -25,6 +28,7 @@ __all__ = [
|
||||
"CassandraChatMessageHistory",
|
||||
"CosmosDBChatMessageHistory",
|
||||
"DynamoDBChatMessageHistory",
|
||||
"ElasticsearchChatMessageHistory",
|
||||
"FileChatMessageHistory",
|
||||
"FirestoreChatMessageHistory",
|
||||
"MomentoChatMessageHistory",
|
||||
|
@@ -0,0 +1,191 @@
|
||||
import json
|
||||
import logging
|
||||
from time import time
|
||||
from typing import TYPE_CHECKING, Any, Dict, List, Optional
|
||||
|
||||
from langchain.schema import BaseChatMessageHistory
|
||||
from langchain.schema.messages import BaseMessage, _message_to_dict, messages_from_dict
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from elasticsearch import Elasticsearch
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class ElasticsearchChatMessageHistory(BaseChatMessageHistory):
    """Chat message history that stores history in Elasticsearch.

    Each message is stored as one document holding the session id, a
    millisecond creation timestamp and the JSON-serialized message.

    Args:
        es_url: URL of the Elasticsearch instance to connect to.
        es_cloud_id: Cloud ID of the Elasticsearch instance to connect to.
        es_user: Username to use when connecting to Elasticsearch.
        es_password: Password to use when connecting to Elasticsearch.
        es_api_key: API key to use when connecting to Elasticsearch.
        es_connection: Optional pre-existing Elasticsearch connection.
        index: Name of the index to use.
        session_id: Arbitrary key that is used to store the messages
            of a single chat session.

    Raises:
        ValueError: If neither ``es_connection`` nor one of ``es_url`` /
            ``es_cloud_id`` is provided.
    """

    # Elasticsearch returns only 10 hits per search unless ``size`` is given.
    # 10000 is the default ``index.max_result_window``, i.e. the largest page
    # a plain search can return without changing index settings.
    _MAX_MESSAGES_PER_SEARCH = 10000

    def __init__(
        self,
        index: str,
        session_id: str,
        *,
        es_connection: Optional["Elasticsearch"] = None,
        es_url: Optional[str] = None,
        es_cloud_id: Optional[str] = None,
        es_user: Optional[str] = None,
        es_api_key: Optional[str] = None,
        es_password: Optional[str] = None,
    ) -> None:
        self.index: str = index
        self.session_id: str = session_id

        # Initialize Elasticsearch client from passed client arg or connection info
        if es_connection is not None:
            self.client = es_connection.options(
                headers={"user-agent": self.get_user_agent()}
            )
        elif es_url is not None or es_cloud_id is not None:
            self.client = ElasticsearchChatMessageHistory.connect_to_elasticsearch(
                es_url=es_url,
                username=es_user,
                password=es_password,
                cloud_id=es_cloud_id,
                api_key=es_api_key,
            )
        else:
            # Implicit concatenation keeps source indentation out of the message.
            raise ValueError(
                "Either provide a pre-existing Elasticsearch connection, "
                "or valid credentials for creating a new connection."
            )

        if self.client.indices.exists(index=index):
            logger.debug(
                f"Chat history index {index} already exists, skipping creation."
            )
        else:
            logger.debug(f"Creating index {index} for storing chat history.")

            self.client.indices.create(
                index=index,
                mappings={
                    "properties": {
                        "session_id": {"type": "keyword"},
                        "created_at": {"type": "date"},
                        "history": {"type": "text"},
                    }
                },
            )

    @staticmethod
    def get_user_agent() -> str:
        """Return the user-agent header value sent with every request."""
        from langchain import __version__

        return f"langchain-py-ms/{__version__}"

    @staticmethod
    def connect_to_elasticsearch(
        *,
        es_url: Optional[str] = None,
        cloud_id: Optional[str] = None,
        api_key: Optional[str] = None,
        username: Optional[str] = None,
        password: Optional[str] = None,
    ) -> "Elasticsearch":
        """Create an Elasticsearch client and verify that it can connect.

        Exactly one of ``es_url`` and ``cloud_id`` must be given. ``api_key``
        takes precedence over ``username``/``password``; basic auth is only
        used when both username and password are set.

        Returns:
            A client on which ``info()`` has been called successfully.

        Raises:
            ImportError: If the ``elasticsearch`` package is not installed.
            ValueError: If both or neither of ``es_url`` / ``cloud_id`` are set.
            Exception: Whatever ``info()`` raises when the cluster is
                unreachable; it is logged and re-raised unchanged.
        """
        try:
            import elasticsearch
        except ImportError:
            raise ImportError(
                "Could not import elasticsearch python package. "
                "Please install it with `pip install elasticsearch`."
            )

        if es_url and cloud_id:
            raise ValueError(
                "Both es_url and cloud_id are defined. Please provide only one."
            )

        connection_params: Dict[str, Any] = {}

        if es_url:
            connection_params["hosts"] = [es_url]
        elif cloud_id:
            connection_params["cloud_id"] = cloud_id
        else:
            raise ValueError("Please provide either es_url or cloud_id.")

        if api_key:
            connection_params["api_key"] = api_key
        elif username and password:
            connection_params["basic_auth"] = (username, password)

        es_client = elasticsearch.Elasticsearch(
            **connection_params,
            headers={"user-agent": ElasticsearchChatMessageHistory.get_user_agent()},
        )
        try:
            # Cheap round trip that fails fast on bad hosts or credentials.
            es_client.info()
        except Exception as err:
            logger.error(f"Error connecting to Elasticsearch: {err}")
            raise

        return es_client

    @property
    def messages(self) -> List[BaseMessage]:  # type: ignore[override]
        """Retrieve the messages from Elasticsearch.

        Messages of this session are returned oldest first. At most
        ``_MAX_MESSAGES_PER_SEARCH`` messages are fetched.

        Raises:
            elasticsearch.ApiError: If the search request fails.
        """
        # Imported outside the ``try`` so a missing package surfaces as an
        # ImportError instead of an unbound name in the ``except`` clause.
        from elasticsearch import ApiError

        try:
            result = self.client.search(
                index=self.index,
                query={"term": {"session_id": self.session_id}},
                sort="created_at:asc",
                # Without an explicit size only the first 10 hits come back.
                size=self._MAX_MESSAGES_PER_SEARCH,
            )
        except ApiError as err:
            logger.error(f"Could not retrieve messages from Elasticsearch: {err}")
            raise

        hits = result["hits"]["hits"] if result else []
        items = [json.loads(hit["_source"]["history"]) for hit in hits]

        return messages_from_dict(items)

    def add_message(self, message: BaseMessage) -> None:
        """Add a message to the chat session in Elasticsearch.

        Raises:
            elasticsearch.ApiError: If the index request fails.
        """
        from elasticsearch import ApiError

        try:
            self.client.index(
                index=self.index,
                document={
                    "session_id": self.session_id,
                    # Epoch milliseconds; used as the sort key when reading.
                    "created_at": round(time() * 1000),
                    "history": json.dumps(_message_to_dict(message)),
                },
                # Make the message visible to the next ``messages`` read.
                refresh=True,
            )
        except ApiError as err:
            logger.error(f"Could not add message to Elasticsearch: {err}")
            raise

    def clear(self) -> None:
        """Clear session memory in Elasticsearch.

        Deletes every document of this session; other sessions stored in
        the same index are not matched by the query.

        Raises:
            elasticsearch.ApiError: If the delete-by-query request fails.
        """
        from elasticsearch import ApiError

        try:
            self.client.delete_by_query(
                index=self.index,
                query={"term": {"session_id": self.session_id}},
                refresh=True,
            )
        except ApiError as err:
            logger.error(f"Could not clear session memory in Elasticsearch: {err}")
            raise
|
@@ -0,0 +1,34 @@
|
||||
# Local Elasticsearch + Kibana stack for the chat-message-history
# integration tests. Security is disabled, so no credentials are needed.
version: "3"

services:
  elasticsearch:
    image: docker.elastic.co/elasticsearch/elasticsearch:8.9.0 # https://www.docker.elastic.co/r/elasticsearch/elasticsearch
    environment:
      - discovery.type=single-node
      - xpack.security.enabled=false # security has been disabled, so no login or password is required.
      - xpack.security.http.ssl.enabled=false
    ports:
      - "9200:9200"
    healthcheck:
      test:
        [
          "CMD-SHELL",
          "curl --silent --fail http://localhost:9200/_cluster/health || exit 1",
        ]
      interval: 10s
      retries: 60

  kibana:
    image: docker.elastic.co/kibana/kibana:8.9.0
    environment:
      # Kibana 7+ is configured through ELASTICSEARCH_HOSTS; the older
      # ELASTICSEARCH_URL (elasticsearch.url) setting was removed in 7.0.
      - ELASTICSEARCH_HOSTS=http://elasticsearch:9200
    ports:
      - "5601:5601"
    healthcheck:
      test:
        [
          "CMD-SHELL",
          "curl --silent --fail http://localhost:5601/login || exit 1",
        ]
      interval: 10s
      retries: 60
|
@@ -0,0 +1,91 @@
|
||||
import json
|
||||
import os
|
||||
import uuid
|
||||
from typing import Generator, Union
|
||||
|
||||
import pytest
|
||||
|
||||
from langchain.memory import ConversationBufferMemory
|
||||
from langchain.memory.chat_message_histories import ElasticsearchChatMessageHistory
|
||||
from langchain.schema.messages import _message_to_dict
|
||||
|
||||
"""
|
||||
cd tests/integration_tests/memory/docker-compose
|
||||
docker-compose -f elasticsearch.yml up
|
||||
|
||||
By default runs against local docker instance of Elasticsearch.
|
||||
To run against Elastic Cloud, set the following environment variables:
|
||||
- ES_CLOUD_ID
|
||||
- ES_USERNAME
|
||||
- ES_PASSWORD
|
||||
"""
|
||||
|
||||
|
||||
class TestElasticsearch:
    """Integration tests for ``ElasticsearchChatMessageHistory``."""

    @pytest.fixture(scope="class", autouse=True)
    def elasticsearch_connection(self) -> Union[dict, Generator[dict, None, None]]:
        """Yield connection kwargs, then delete every ``test_`` index.

        Targets Elastic Cloud when ``ES_CLOUD_ID`` is set, otherwise the
        Elasticsearch instance at ``ES_URL`` (localhost by default).
        """
        from elasticsearch import Elasticsearch

        local_url = os.environ.get("ES_URL", "http://localhost:9200")
        cloud_id = os.environ.get("ES_CLOUD_ID")
        user = os.environ.get("ES_USERNAME", "elastic")
        secret = os.environ.get("ES_PASSWORD", "changeme")

        if cloud_id:
            client = Elasticsearch(cloud_id=cloud_id, basic_auth=(user, secret))
            connection_kwargs = {
                "es_cloud_id": cloud_id,
                "es_user": user,
                "es_password": secret,
            }
        else:
            # Local docker instance: no credentials required.
            client = Elasticsearch(hosts=local_url)
            connection_kwargs = {"es_url": local_url}

        yield connection_kwargs

        # Teardown: drop only the indexes created by these tests.
        existing = client.indices.get(index="_all").keys()
        for name in existing:
            if not name.startswith("test_"):
                continue
            client.indices.delete(index=name)
        client.indices.refresh(index="_all")

    @pytest.fixture(scope="function")
    def index_name(self) -> str:
        """Return a unique ``test_``-prefixed index name."""
        unique_suffix = uuid.uuid4().hex
        return f"test_{unique_suffix}"

    def test_memory_with_message_store(
        self, elasticsearch_connection: dict, index_name: str
    ) -> None:
        """Round-trip two messages through the store, then clear it."""
        # Elasticsearch-backed history wrapped in a buffer memory
        history = ElasticsearchChatMessageHistory(
            **elasticsearch_connection, index=index_name, session_id="test-session"
        )
        memory = ConversationBufferMemory(
            memory_key="baz", chat_memory=history, return_messages=True
        )

        # write one AI message and one human message
        memory.chat_memory.add_ai_message("This is me, the AI")
        memory.chat_memory.add_user_message("This is me, the human")

        # read the history back and serialize it for substring checks
        stored = memory.chat_memory.messages
        serialized = json.dumps([_message_to_dict(item) for item in stored])

        assert "This is me, the AI" in serialized
        assert "This is me, the human" in serialized

        # remove the records so the next test run won't pick them up
        memory.chat_memory.clear()

        assert memory.chat_memory.messages == []
|
Reference in New Issue
Block a user