Add ElasticsearchChatMessageHistory (#10932)

**Description**

This PR adds the `ElasticsearchChatMessageHistory` implementation that
stores chat message history in the configured
[Elasticsearch](https://www.elastic.co/elasticsearch/) deployment.

```python
from langchain.memory.chat_message_histories import ElasticsearchChatMessageHistory

history = ElasticsearchChatMessageHistory(
    es_url="https://my-elasticsearch-deployment-url:9200", index="chat-history-index", session_id="123"
)

history.add_ai_message("This is me, the AI")
history.add_user_message("This is me, the human")
```

**Dependencies**
- [elasticsearch client](https://elasticsearch-py.readthedocs.io/)
required

Co-authored-by: Bagatur <baskaryan@gmail.com>
This commit is contained in:
Adam Demjen
2023-10-12 19:51:38 -04:00
committed by GitHub
parent d3a5090e12
commit 008348ce71
6 changed files with 508 additions and 0 deletions

View File

@@ -36,6 +36,7 @@ from langchain.memory.chat_message_histories import (
ChatMessageHistory,
CosmosDBChatMessageHistory,
DynamoDBChatMessageHistory,
ElasticsearchChatMessageHistory,
FileChatMessageHistory,
MomentoChatMessageHistory,
MongoDBChatMessageHistory,
@@ -77,6 +78,7 @@ __all__ = [
"ConversationTokenBufferMemory",
"CosmosDBChatMessageHistory",
"DynamoDBChatMessageHistory",
"ElasticsearchChatMessageHistory",
"FileChatMessageHistory",
"InMemoryEntityStore",
"MomentoChatMessageHistory",

View File

@@ -3,6 +3,9 @@ from langchain.memory.chat_message_histories.cassandra import (
)
from langchain.memory.chat_message_histories.cosmos_db import CosmosDBChatMessageHistory
from langchain.memory.chat_message_histories.dynamodb import DynamoDBChatMessageHistory
from langchain.memory.chat_message_histories.elasticsearch import (
ElasticsearchChatMessageHistory,
)
from langchain.memory.chat_message_histories.file import FileChatMessageHistory
from langchain.memory.chat_message_histories.firestore import (
FirestoreChatMessageHistory,
@@ -25,6 +28,7 @@ __all__ = [
"CassandraChatMessageHistory",
"CosmosDBChatMessageHistory",
"DynamoDBChatMessageHistory",
"ElasticsearchChatMessageHistory",
"FileChatMessageHistory",
"FirestoreChatMessageHistory",
"MomentoChatMessageHistory",

View File

@@ -0,0 +1,191 @@
import json
import logging
from time import time
from typing import TYPE_CHECKING, Any, Dict, List, Optional
from langchain.schema import BaseChatMessageHistory
from langchain.schema.messages import BaseMessage, _message_to_dict, messages_from_dict
if TYPE_CHECKING:
from elasticsearch import Elasticsearch
logger = logging.getLogger(__name__)
class ElasticsearchChatMessageHistory(BaseChatMessageHistory):
"""Chat message history that stores history in Elasticsearch.
Args:
es_url: URL of the Elasticsearch instance to connect to.
es_cloud_id: Cloud ID of the Elasticsearch instance to connect to.
es_user: Username to use when connecting to Elasticsearch.
es_password: Password to use when connecting to Elasticsearch.
es_api_key: API key to use when connecting to Elasticsearch.
es_connection: Optional pre-existing Elasticsearch connection.
index: Name of the index to use.
session_id: Arbitrary key that is used to store the messages
of a single chat session.
"""
def __init__(
self,
index: str,
session_id: str,
*,
es_connection: Optional["Elasticsearch"] = None,
es_url: Optional[str] = None,
es_cloud_id: Optional[str] = None,
es_user: Optional[str] = None,
es_api_key: Optional[str] = None,
es_password: Optional[str] = None,
):
self.index: str = index
self.session_id: str = session_id
# Initialize Elasticsearch client from passed client arg or connection info
if es_connection is not None:
self.client = es_connection.options(
headers={"user-agent": self.get_user_agent()}
)
elif es_url is not None or es_cloud_id is not None:
self.client = ElasticsearchChatMessageHistory.connect_to_elasticsearch(
es_url=es_url,
username=es_user,
password=es_password,
cloud_id=es_cloud_id,
api_key=es_api_key,
)
else:
raise ValueError(
"""Either provide a pre-existing Elasticsearch connection, \
or valid credentials for creating a new connection."""
)
if self.client.indices.exists(index=index):
logger.debug(
f"Chat history index {index} already exists, skipping creation."
)
else:
logger.debug(f"Creating index {index} for storing chat history.")
self.client.indices.create(
index=index,
mappings={
"properties": {
"session_id": {"type": "keyword"},
"created_at": {"type": "date"},
"history": {"type": "text"},
}
},
)
@staticmethod
def get_user_agent() -> str:
from langchain import __version__
return f"langchain-py-ms/{__version__}"
@staticmethod
def connect_to_elasticsearch(
*,
es_url: Optional[str] = None,
cloud_id: Optional[str] = None,
api_key: Optional[str] = None,
username: Optional[str] = None,
password: Optional[str] = None,
) -> "Elasticsearch":
try:
import elasticsearch
except ImportError:
raise ImportError(
"Could not import elasticsearch python package. "
"Please install it with `pip install elasticsearch`."
)
if es_url and cloud_id:
raise ValueError(
"Both es_url and cloud_id are defined. Please provide only one."
)
connection_params: Dict[str, Any] = {}
if es_url:
connection_params["hosts"] = [es_url]
elif cloud_id:
connection_params["cloud_id"] = cloud_id
else:
raise ValueError("Please provide either elasticsearch_url or cloud_id.")
if api_key:
connection_params["api_key"] = api_key
elif username and password:
connection_params["basic_auth"] = (username, password)
es_client = elasticsearch.Elasticsearch(
**connection_params,
headers={"user-agent": ElasticsearchChatMessageHistory.get_user_agent()},
)
try:
es_client.info()
except Exception as err:
logger.error(f"Error connecting to Elasticsearch: {err}")
raise err
return es_client
@property
def messages(self) -> List[BaseMessage]: # type: ignore[override]
"""Retrieve the messages from Elasticsearch"""
try:
from elasticsearch import ApiError
result = self.client.search(
index=self.index,
query={"term": {"session_id": self.session_id}},
sort="created_at:asc",
)
except ApiError as err:
logger.error(f"Could not retrieve messages from Elasticsearch: {err}")
raise err
if result and len(result["hits"]["hits"]) > 0:
items = [
json.loads(document["_source"]["history"])
for document in result["hits"]["hits"]
]
else:
items = []
return messages_from_dict(items)
def add_message(self, message: BaseMessage) -> None:
"""Add a message to the chat session in Elasticsearch"""
try:
from elasticsearch import ApiError
self.client.index(
index=self.index,
document={
"session_id": self.session_id,
"created_at": round(time() * 1000),
"history": json.dumps(_message_to_dict(message)),
},
refresh=True,
)
except ApiError as err:
logger.error(f"Could not add message to Elasticsearch: {err}")
raise err
def clear(self) -> None:
"""Clear session memory in Elasticsearch"""
try:
from elasticsearch import ApiError
self.client.delete_by_query(
index=self.index,
query={"term": {"session_id": self.session_id}},
refresh=True,
)
except ApiError as err:
logger.error(f"Could not clear session memory in Elasticsearch: {err}")
raise err

View File

@@ -0,0 +1,34 @@
version: "3"
services:
elasticsearch:
image: docker.elastic.co/elasticsearch/elasticsearch:8.9.0 # https://www.docker.elastic.co/r/elasticsearch/elasticsearch
environment:
- discovery.type=single-node
- xpack.security.enabled=false # security has been disabled, so no login or password is required.
- xpack.security.http.ssl.enabled=false
ports:
- "9200:9200"
healthcheck:
test:
[
"CMD-SHELL",
"curl --silent --fail http://localhost:9200/_cluster/health || exit 1",
]
interval: 10s
retries: 60
kibana:
image: docker.elastic.co/kibana/kibana:8.9.0
environment:
- ELASTICSEARCH_URL=http://elasticsearch:9200
ports:
- "5601:5601"
healthcheck:
test:
[
"CMD-SHELL",
"curl --silent --fail http://localhost:5601/login || exit 1",
]
interval: 10s
retries: 60

View File

@@ -0,0 +1,91 @@
import json
import os
import uuid
from typing import Generator, Union
import pytest
from langchain.memory import ConversationBufferMemory
from langchain.memory.chat_message_histories import ElasticsearchChatMessageHistory
from langchain.schema.messages import _message_to_dict
"""
cd tests/integration_tests/memory/docker-compose
docker-compose -f elasticsearch.yml up
By default runs against local docker instance of Elasticsearch.
To run against Elastic Cloud, set the following environment variables:
- ES_CLOUD_ID
- ES_USERNAME
- ES_PASSWORD
"""
class TestElasticsearch:
@pytest.fixture(scope="class", autouse=True)
def elasticsearch_connection(self) -> Union[dict, Generator[dict, None, None]]:
# Run this integration test against Elasticsearch on localhost,
# or an Elastic Cloud instance
from elasticsearch import Elasticsearch
es_url = os.environ.get("ES_URL", "http://localhost:9200")
es_cloud_id = os.environ.get("ES_CLOUD_ID")
es_username = os.environ.get("ES_USERNAME", "elastic")
es_password = os.environ.get("ES_PASSWORD", "changeme")
if es_cloud_id:
es = Elasticsearch(
cloud_id=es_cloud_id,
basic_auth=(es_username, es_password),
)
yield {
"es_cloud_id": es_cloud_id,
"es_user": es_username,
"es_password": es_password,
}
else:
# Running this integration test with local docker instance
es = Elasticsearch(hosts=es_url)
yield {"es_url": es_url}
# Clear all indexes
index_names = es.indices.get(index="_all").keys()
for index_name in index_names:
if index_name.startswith("test_"):
es.indices.delete(index=index_name)
es.indices.refresh(index="_all")
@pytest.fixture(scope="function")
def index_name(self) -> str:
"""Return the index name."""
return f"test_{uuid.uuid4().hex}"
def test_memory_with_message_store(
self, elasticsearch_connection: dict, index_name: str
) -> None:
"""Test the memory with a message store."""
# setup Elasticsearch as a message store
message_history = ElasticsearchChatMessageHistory(
**elasticsearch_connection, index=index_name, session_id="test-session"
)
memory = ConversationBufferMemory(
memory_key="baz", chat_memory=message_history, return_messages=True
)
# add some messages
memory.chat_memory.add_ai_message("This is me, the AI")
memory.chat_memory.add_user_message("This is me, the human")
# get the message history from the memory store and turn it into a json
messages = memory.chat_memory.messages
messages_json = json.dumps([_message_to_dict(msg) for msg in messages])
assert "This is me, the AI" in messages_json
assert "This is me, the human" in messages_json
# remove the record from Elasticsearch, so the next test run won't pick it up
memory.chat_memory.clear()
assert memory.chat_memory.messages == []