mirror of
https://github.com/hwchase17/langchain.git
synced 2025-06-24 23:54:14 +00:00
Add ElasticsearchChatMessageHistory (#10932)
**Description** This PR adds the `ElasticsearchChatMessageHistory` implementation that stores chat message history in the configured [Elasticsearch](https://www.elastic.co/elasticsearch/) deployment. ```python from langchain.memory.chat_message_histories import ElasticsearchChatMessageHistory history = ElasticsearchChatMessageHistory( es_url="https://my-elasticsearch-deployment-url:9200", index="chat-history-index", session_id="123" ) history.add_ai_message("This is me, the AI") history.add_user_message("This is me, the human") ``` **Dependencies** - [elasticsearch client](https://elasticsearch-py.readthedocs.io/) required Co-authored-by: Bagatur <baskaryan@gmail.com>
This commit is contained in:
parent
d3a5090e12
commit
008348ce71
@ -0,0 +1,186 @@
|
|||||||
|
{
|
||||||
|
"cells": [
|
||||||
|
{
|
||||||
|
"cell_type": "markdown",
|
||||||
|
"id": "683953b3",
|
||||||
|
"metadata": {
|
||||||
|
"id": "683953b3"
|
||||||
|
},
|
||||||
|
"source": [
|
||||||
|
"# Elasticsearch Chat Message History\n",
|
||||||
|
"\n",
|
||||||
|
">[Elasticsearch](https://www.elastic.co/elasticsearch/) is a distributed, RESTful search and analytics engine, capable of performing both vector and lexical search. It is built on top of the Apache Lucene library.\n",
|
||||||
|
"\n",
|
||||||
|
"This notebook shows how to use chat message history functionality with Elasticsearch."
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "markdown",
|
||||||
|
"id": "3c7720c3",
|
||||||
|
"metadata": {},
|
||||||
|
"source": [
|
||||||
|
"## Set up Elasticsearch\n",
|
||||||
|
"\n",
|
||||||
|
"There are two main ways to set up an Elasticsearch instance:\n",
|
||||||
|
"\n",
|
||||||
|
"1. **Elastic Cloud.** Elastic Cloud is a managed Elasticsearch service. Sign up for a [free trial](https://cloud.elastic.co/registration?storm=langchain-notebook).\n",
|
||||||
|
"\n",
|
||||||
|
"2. **Local Elasticsearch installation.** Get started with Elasticsearch by running it locally. The easiest way is to use the official Elasticsearch Docker image. See the [Elasticsearch Docker documentation](https://www.elastic.co/guide/en/elasticsearch/reference/current/docker.html) for more information."
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "markdown",
|
||||||
|
"id": "cdf1d2b7",
|
||||||
|
"metadata": {},
|
||||||
|
"source": [
|
||||||
|
"## Install dependencies"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": null,
|
||||||
|
"id": "e5bbffe2",
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"%pip install elasticsearch langchain"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "markdown",
|
||||||
|
"id": "8be8fcc3",
|
||||||
|
"metadata": {},
|
||||||
|
"source": [
|
||||||
|
"## Initialize Elasticsearch client and chat message history"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 2,
|
||||||
|
"id": "8e2ee0fa",
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"import os\n",
|
||||||
|
"from langchain.memory import ElasticsearchChatMessageHistory\n",
|
||||||
|
"\n",
|
||||||
|
"es_url = os.environ.get(\"ES_URL\", \"http://localhost:9200\")\n",
|
||||||
|
"\n",
|
||||||
|
"# If using Elastic Cloud:\n",
|
||||||
|
"# es_cloud_id = os.environ.get(\"ES_CLOUD_ID\")\n",
|
||||||
|
"\n",
|
||||||
|
"# Note: see Authentication section for various authentication methods\n",
|
||||||
|
"\n",
|
||||||
|
"history = ElasticsearchChatMessageHistory(\n",
|
||||||
|
" es_url=es_url,\n",
|
||||||
|
" index=\"test-history\",\n",
|
||||||
|
" session_id=\"test-session\"\n",
|
||||||
|
")"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "markdown",
|
||||||
|
"id": "a63942e2",
|
||||||
|
"metadata": {},
|
||||||
|
"source": [
|
||||||
|
"## Use the chat message history"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 3,
|
||||||
|
"id": "c1c7be79",
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [
|
||||||
|
{
|
||||||
|
"name": "stdout",
|
||||||
|
"output_type": "stream",
|
||||||
|
"text": [
|
||||||
|
"indexing message content='hi!' additional_kwargs={} example=False\n",
|
||||||
|
"indexing message content='whats up?' additional_kwargs={} example=False\n"
|
||||||
|
]
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"source": [
|
||||||
|
"history.add_user_message(\"hi!\")\n",
|
||||||
|
"history.add_ai_message(\"whats up?\")"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "markdown",
|
||||||
|
"id": "c46c216c",
|
||||||
|
"metadata": {},
|
||||||
|
"source": [
|
||||||
|
"# Authentication\n",
|
||||||
|
"\n",
|
||||||
|
"## Username/password\n",
|
||||||
|
"\n",
|
||||||
|
"```python\n",
|
||||||
|
"es_username = os.environ.get(\"ES_USERNAME\", \"elastic\")\n",
|
||||||
|
"es_password = os.environ.get(\"ES_PASSWORD\", \"changeme\")\n",
|
||||||
|
"\n",
|
||||||
|
"history = ElasticsearchChatMessageHistory(\n",
|
||||||
|
" es_url=es_url,\n",
|
||||||
|
" es_user=es_username,\n",
|
||||||
|
" es_password=es_password,\n",
|
||||||
|
" index=\"test-history\",\n",
|
||||||
|
" session_id=\"test-session\"\n",
|
||||||
|
")\n",
|
||||||
|
"```\n",
|
||||||
|
"\n",
|
||||||
|
"### How to obtain a password for the default \"elastic\" user\n",
|
||||||
|
"\n",
|
||||||
|
"To obtain your Elastic Cloud password for the default \"elastic\" user:\n",
|
||||||
|
"1. Log in to the Elastic Cloud console at https://cloud.elastic.co\n",
|
||||||
|
"2. Go to \"Security\" > \"Users\"\n",
|
||||||
|
"3. Locate the \"elastic\" user and click \"Edit\"\n",
|
||||||
|
"4. Click \"Reset password\"\n",
|
||||||
|
"5. Follow the prompts to reset the password\n",
|
||||||
|
"\n",
|
||||||
|
"## API key\n",
|
||||||
|
"\n",
|
||||||
|
"```python\n",
|
||||||
|
"es_api_key = os.environ.get(\"ES_API_KEY\")\n",
|
||||||
|
"\n",
|
||||||
|
"history = ElasticsearchChatMessageHistory(\n",
|
||||||
|
" es_api_key=es_api_key,\n",
|
||||||
|
" index=\"test-history\",\n",
|
||||||
|
" session_id=\"test-session\"\n",
|
||||||
|
")\n",
|
||||||
|
"```\n",
|
||||||
|
"\n",
|
||||||
|
"### How to obtain an API key\n",
|
||||||
|
"\n",
|
||||||
|
"To obtain an API key:\n",
|
||||||
|
"1. Log in to the Elastic Cloud console at https://cloud.elastic.co\n",
|
||||||
|
"2. Open Kibana and go to Stack Management > API Keys\n",
|
||||||
|
"3. Click \"Create API key\"\n",
|
||||||
|
"4. Enter a name for the API key and click \"Create\""
|
||||||
|
]
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"metadata": {
|
||||||
|
"colab": {
|
||||||
|
"provenance": []
|
||||||
|
},
|
||||||
|
"kernelspec": {
|
||||||
|
"display_name": "Python 3 (ipykernel)",
|
||||||
|
"language": "python",
|
||||||
|
"name": "python3"
|
||||||
|
},
|
||||||
|
"language_info": {
|
||||||
|
"codemirror_mode": {
|
||||||
|
"name": "ipython",
|
||||||
|
"version": 3
|
||||||
|
},
|
||||||
|
"file_extension": ".py",
|
||||||
|
"mimetype": "text/x-python",
|
||||||
|
"name": "python",
|
||||||
|
"nbconvert_exporter": "python",
|
||||||
|
"pygments_lexer": "ipython3",
|
||||||
|
"version": "3.10.9"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"nbformat": 4,
|
||||||
|
"nbformat_minor": 5
|
||||||
|
}
|
@ -36,6 +36,7 @@ from langchain.memory.chat_message_histories import (
|
|||||||
ChatMessageHistory,
|
ChatMessageHistory,
|
||||||
CosmosDBChatMessageHistory,
|
CosmosDBChatMessageHistory,
|
||||||
DynamoDBChatMessageHistory,
|
DynamoDBChatMessageHistory,
|
||||||
|
ElasticsearchChatMessageHistory,
|
||||||
FileChatMessageHistory,
|
FileChatMessageHistory,
|
||||||
MomentoChatMessageHistory,
|
MomentoChatMessageHistory,
|
||||||
MongoDBChatMessageHistory,
|
MongoDBChatMessageHistory,
|
||||||
@ -77,6 +78,7 @@ __all__ = [
|
|||||||
"ConversationTokenBufferMemory",
|
"ConversationTokenBufferMemory",
|
||||||
"CosmosDBChatMessageHistory",
|
"CosmosDBChatMessageHistory",
|
||||||
"DynamoDBChatMessageHistory",
|
"DynamoDBChatMessageHistory",
|
||||||
|
"ElasticsearchChatMessageHistory",
|
||||||
"FileChatMessageHistory",
|
"FileChatMessageHistory",
|
||||||
"InMemoryEntityStore",
|
"InMemoryEntityStore",
|
||||||
"MomentoChatMessageHistory",
|
"MomentoChatMessageHistory",
|
||||||
|
@ -3,6 +3,9 @@ from langchain.memory.chat_message_histories.cassandra import (
|
|||||||
)
|
)
|
||||||
from langchain.memory.chat_message_histories.cosmos_db import CosmosDBChatMessageHistory
|
from langchain.memory.chat_message_histories.cosmos_db import CosmosDBChatMessageHistory
|
||||||
from langchain.memory.chat_message_histories.dynamodb import DynamoDBChatMessageHistory
|
from langchain.memory.chat_message_histories.dynamodb import DynamoDBChatMessageHistory
|
||||||
|
from langchain.memory.chat_message_histories.elasticsearch import (
|
||||||
|
ElasticsearchChatMessageHistory,
|
||||||
|
)
|
||||||
from langchain.memory.chat_message_histories.file import FileChatMessageHistory
|
from langchain.memory.chat_message_histories.file import FileChatMessageHistory
|
||||||
from langchain.memory.chat_message_histories.firestore import (
|
from langchain.memory.chat_message_histories.firestore import (
|
||||||
FirestoreChatMessageHistory,
|
FirestoreChatMessageHistory,
|
||||||
@ -25,6 +28,7 @@ __all__ = [
|
|||||||
"CassandraChatMessageHistory",
|
"CassandraChatMessageHistory",
|
||||||
"CosmosDBChatMessageHistory",
|
"CosmosDBChatMessageHistory",
|
||||||
"DynamoDBChatMessageHistory",
|
"DynamoDBChatMessageHistory",
|
||||||
|
"ElasticsearchChatMessageHistory",
|
||||||
"FileChatMessageHistory",
|
"FileChatMessageHistory",
|
||||||
"FirestoreChatMessageHistory",
|
"FirestoreChatMessageHistory",
|
||||||
"MomentoChatMessageHistory",
|
"MomentoChatMessageHistory",
|
||||||
|
@ -0,0 +1,191 @@
|
|||||||
|
import json
|
||||||
|
import logging
|
||||||
|
from time import time
|
||||||
|
from typing import TYPE_CHECKING, Any, Dict, List, Optional
|
||||||
|
|
||||||
|
from langchain.schema import BaseChatMessageHistory
|
||||||
|
from langchain.schema.messages import BaseMessage, _message_to_dict, messages_from_dict
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from elasticsearch import Elasticsearch
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class ElasticsearchChatMessageHistory(BaseChatMessageHistory):
|
||||||
|
"""Chat message history that stores history in Elasticsearch.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
es_url: URL of the Elasticsearch instance to connect to.
|
||||||
|
es_cloud_id: Cloud ID of the Elasticsearch instance to connect to.
|
||||||
|
es_user: Username to use when connecting to Elasticsearch.
|
||||||
|
es_password: Password to use when connecting to Elasticsearch.
|
||||||
|
es_api_key: API key to use when connecting to Elasticsearch.
|
||||||
|
es_connection: Optional pre-existing Elasticsearch connection.
|
||||||
|
index: Name of the index to use.
|
||||||
|
session_id: Arbitrary key that is used to store the messages
|
||||||
|
of a single chat session.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
index: str,
|
||||||
|
session_id: str,
|
||||||
|
*,
|
||||||
|
es_connection: Optional["Elasticsearch"] = None,
|
||||||
|
es_url: Optional[str] = None,
|
||||||
|
es_cloud_id: Optional[str] = None,
|
||||||
|
es_user: Optional[str] = None,
|
||||||
|
es_api_key: Optional[str] = None,
|
||||||
|
es_password: Optional[str] = None,
|
||||||
|
):
|
||||||
|
self.index: str = index
|
||||||
|
self.session_id: str = session_id
|
||||||
|
|
||||||
|
# Initialize Elasticsearch client from passed client arg or connection info
|
||||||
|
if es_connection is not None:
|
||||||
|
self.client = es_connection.options(
|
||||||
|
headers={"user-agent": self.get_user_agent()}
|
||||||
|
)
|
||||||
|
elif es_url is not None or es_cloud_id is not None:
|
||||||
|
self.client = ElasticsearchChatMessageHistory.connect_to_elasticsearch(
|
||||||
|
es_url=es_url,
|
||||||
|
username=es_user,
|
||||||
|
password=es_password,
|
||||||
|
cloud_id=es_cloud_id,
|
||||||
|
api_key=es_api_key,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
raise ValueError(
|
||||||
|
"""Either provide a pre-existing Elasticsearch connection, \
|
||||||
|
or valid credentials for creating a new connection."""
|
||||||
|
)
|
||||||
|
|
||||||
|
if self.client.indices.exists(index=index):
|
||||||
|
logger.debug(
|
||||||
|
f"Chat history index {index} already exists, skipping creation."
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
logger.debug(f"Creating index {index} for storing chat history.")
|
||||||
|
|
||||||
|
self.client.indices.create(
|
||||||
|
index=index,
|
||||||
|
mappings={
|
||||||
|
"properties": {
|
||||||
|
"session_id": {"type": "keyword"},
|
||||||
|
"created_at": {"type": "date"},
|
||||||
|
"history": {"type": "text"},
|
||||||
|
}
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def get_user_agent() -> str:
|
||||||
|
from langchain import __version__
|
||||||
|
|
||||||
|
return f"langchain-py-ms/{__version__}"
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def connect_to_elasticsearch(
|
||||||
|
*,
|
||||||
|
es_url: Optional[str] = None,
|
||||||
|
cloud_id: Optional[str] = None,
|
||||||
|
api_key: Optional[str] = None,
|
||||||
|
username: Optional[str] = None,
|
||||||
|
password: Optional[str] = None,
|
||||||
|
) -> "Elasticsearch":
|
||||||
|
try:
|
||||||
|
import elasticsearch
|
||||||
|
except ImportError:
|
||||||
|
raise ImportError(
|
||||||
|
"Could not import elasticsearch python package. "
|
||||||
|
"Please install it with `pip install elasticsearch`."
|
||||||
|
)
|
||||||
|
|
||||||
|
if es_url and cloud_id:
|
||||||
|
raise ValueError(
|
||||||
|
"Both es_url and cloud_id are defined. Please provide only one."
|
||||||
|
)
|
||||||
|
|
||||||
|
connection_params: Dict[str, Any] = {}
|
||||||
|
|
||||||
|
if es_url:
|
||||||
|
connection_params["hosts"] = [es_url]
|
||||||
|
elif cloud_id:
|
||||||
|
connection_params["cloud_id"] = cloud_id
|
||||||
|
else:
|
||||||
|
raise ValueError("Please provide either elasticsearch_url or cloud_id.")
|
||||||
|
|
||||||
|
if api_key:
|
||||||
|
connection_params["api_key"] = api_key
|
||||||
|
elif username and password:
|
||||||
|
connection_params["basic_auth"] = (username, password)
|
||||||
|
|
||||||
|
es_client = elasticsearch.Elasticsearch(
|
||||||
|
**connection_params,
|
||||||
|
headers={"user-agent": ElasticsearchChatMessageHistory.get_user_agent()},
|
||||||
|
)
|
||||||
|
try:
|
||||||
|
es_client.info()
|
||||||
|
except Exception as err:
|
||||||
|
logger.error(f"Error connecting to Elasticsearch: {err}")
|
||||||
|
raise err
|
||||||
|
|
||||||
|
return es_client
|
||||||
|
|
||||||
|
@property
|
||||||
|
def messages(self) -> List[BaseMessage]: # type: ignore[override]
|
||||||
|
"""Retrieve the messages from Elasticsearch"""
|
||||||
|
try:
|
||||||
|
from elasticsearch import ApiError
|
||||||
|
|
||||||
|
result = self.client.search(
|
||||||
|
index=self.index,
|
||||||
|
query={"term": {"session_id": self.session_id}},
|
||||||
|
sort="created_at:asc",
|
||||||
|
)
|
||||||
|
except ApiError as err:
|
||||||
|
logger.error(f"Could not retrieve messages from Elasticsearch: {err}")
|
||||||
|
raise err
|
||||||
|
|
||||||
|
if result and len(result["hits"]["hits"]) > 0:
|
||||||
|
items = [
|
||||||
|
json.loads(document["_source"]["history"])
|
||||||
|
for document in result["hits"]["hits"]
|
||||||
|
]
|
||||||
|
else:
|
||||||
|
items = []
|
||||||
|
|
||||||
|
return messages_from_dict(items)
|
||||||
|
|
||||||
|
def add_message(self, message: BaseMessage) -> None:
|
||||||
|
"""Add a message to the chat session in Elasticsearch"""
|
||||||
|
try:
|
||||||
|
from elasticsearch import ApiError
|
||||||
|
|
||||||
|
self.client.index(
|
||||||
|
index=self.index,
|
||||||
|
document={
|
||||||
|
"session_id": self.session_id,
|
||||||
|
"created_at": round(time() * 1000),
|
||||||
|
"history": json.dumps(_message_to_dict(message)),
|
||||||
|
},
|
||||||
|
refresh=True,
|
||||||
|
)
|
||||||
|
except ApiError as err:
|
||||||
|
logger.error(f"Could not add message to Elasticsearch: {err}")
|
||||||
|
raise err
|
||||||
|
|
||||||
|
def clear(self) -> None:
|
||||||
|
"""Clear session memory in Elasticsearch"""
|
||||||
|
try:
|
||||||
|
from elasticsearch import ApiError
|
||||||
|
|
||||||
|
self.client.delete_by_query(
|
||||||
|
index=self.index,
|
||||||
|
query={"term": {"session_id": self.session_id}},
|
||||||
|
refresh=True,
|
||||||
|
)
|
||||||
|
except ApiError as err:
|
||||||
|
logger.error(f"Could not clear session memory in Elasticsearch: {err}")
|
||||||
|
raise err
|
@ -0,0 +1,34 @@
|
|||||||
|
version: "3"
|
||||||
|
|
||||||
|
services:
|
||||||
|
elasticsearch:
|
||||||
|
image: docker.elastic.co/elasticsearch/elasticsearch:8.9.0 # https://www.docker.elastic.co/r/elasticsearch/elasticsearch
|
||||||
|
environment:
|
||||||
|
- discovery.type=single-node
|
||||||
|
- xpack.security.enabled=false # security has been disabled, so no login or password is required.
|
||||||
|
- xpack.security.http.ssl.enabled=false
|
||||||
|
ports:
|
||||||
|
- "9200:9200"
|
||||||
|
healthcheck:
|
||||||
|
test:
|
||||||
|
[
|
||||||
|
"CMD-SHELL",
|
||||||
|
"curl --silent --fail http://localhost:9200/_cluster/health || exit 1",
|
||||||
|
]
|
||||||
|
interval: 10s
|
||||||
|
retries: 60
|
||||||
|
|
||||||
|
kibana:
|
||||||
|
image: docker.elastic.co/kibana/kibana:8.9.0
|
||||||
|
environment:
|
||||||
|
- ELASTICSEARCH_URL=http://elasticsearch:9200
|
||||||
|
ports:
|
||||||
|
- "5601:5601"
|
||||||
|
healthcheck:
|
||||||
|
test:
|
||||||
|
[
|
||||||
|
"CMD-SHELL",
|
||||||
|
"curl --silent --fail http://localhost:5601/login || exit 1",
|
||||||
|
]
|
||||||
|
interval: 10s
|
||||||
|
retries: 60
|
@ -0,0 +1,91 @@
|
|||||||
|
import json
|
||||||
|
import os
|
||||||
|
import uuid
|
||||||
|
from typing import Generator, Union
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from langchain.memory import ConversationBufferMemory
|
||||||
|
from langchain.memory.chat_message_histories import ElasticsearchChatMessageHistory
|
||||||
|
from langchain.schema.messages import _message_to_dict
|
||||||
|
|
||||||
|
"""
|
||||||
|
cd tests/integration_tests/memory/docker-compose
|
||||||
|
docker-compose -f elasticsearch.yml up
|
||||||
|
|
||||||
|
By default runs against local docker instance of Elasticsearch.
|
||||||
|
To run against Elastic Cloud, set the following environment variables:
|
||||||
|
- ES_CLOUD_ID
|
||||||
|
- ES_USERNAME
|
||||||
|
- ES_PASSWORD
|
||||||
|
"""
|
||||||
|
|
||||||
|
|
||||||
|
class TestElasticsearch:
|
||||||
|
@pytest.fixture(scope="class", autouse=True)
|
||||||
|
def elasticsearch_connection(self) -> Union[dict, Generator[dict, None, None]]:
|
||||||
|
# Run this integration test against Elasticsearch on localhost,
|
||||||
|
# or an Elastic Cloud instance
|
||||||
|
from elasticsearch import Elasticsearch
|
||||||
|
|
||||||
|
es_url = os.environ.get("ES_URL", "http://localhost:9200")
|
||||||
|
es_cloud_id = os.environ.get("ES_CLOUD_ID")
|
||||||
|
es_username = os.environ.get("ES_USERNAME", "elastic")
|
||||||
|
es_password = os.environ.get("ES_PASSWORD", "changeme")
|
||||||
|
|
||||||
|
if es_cloud_id:
|
||||||
|
es = Elasticsearch(
|
||||||
|
cloud_id=es_cloud_id,
|
||||||
|
basic_auth=(es_username, es_password),
|
||||||
|
)
|
||||||
|
yield {
|
||||||
|
"es_cloud_id": es_cloud_id,
|
||||||
|
"es_user": es_username,
|
||||||
|
"es_password": es_password,
|
||||||
|
}
|
||||||
|
|
||||||
|
else:
|
||||||
|
# Running this integration test with local docker instance
|
||||||
|
es = Elasticsearch(hosts=es_url)
|
||||||
|
yield {"es_url": es_url}
|
||||||
|
|
||||||
|
# Clear all indexes
|
||||||
|
index_names = es.indices.get(index="_all").keys()
|
||||||
|
for index_name in index_names:
|
||||||
|
if index_name.startswith("test_"):
|
||||||
|
es.indices.delete(index=index_name)
|
||||||
|
es.indices.refresh(index="_all")
|
||||||
|
|
||||||
|
@pytest.fixture(scope="function")
|
||||||
|
def index_name(self) -> str:
|
||||||
|
"""Return the index name."""
|
||||||
|
return f"test_{uuid.uuid4().hex}"
|
||||||
|
|
||||||
|
def test_memory_with_message_store(
|
||||||
|
self, elasticsearch_connection: dict, index_name: str
|
||||||
|
) -> None:
|
||||||
|
"""Test the memory with a message store."""
|
||||||
|
# setup Elasticsearch as a message store
|
||||||
|
message_history = ElasticsearchChatMessageHistory(
|
||||||
|
**elasticsearch_connection, index=index_name, session_id="test-session"
|
||||||
|
)
|
||||||
|
|
||||||
|
memory = ConversationBufferMemory(
|
||||||
|
memory_key="baz", chat_memory=message_history, return_messages=True
|
||||||
|
)
|
||||||
|
|
||||||
|
# add some messages
|
||||||
|
memory.chat_memory.add_ai_message("This is me, the AI")
|
||||||
|
memory.chat_memory.add_user_message("This is me, the human")
|
||||||
|
|
||||||
|
# get the message history from the memory store and turn it into a json
|
||||||
|
messages = memory.chat_memory.messages
|
||||||
|
messages_json = json.dumps([_message_to_dict(msg) for msg in messages])
|
||||||
|
|
||||||
|
assert "This is me, the AI" in messages_json
|
||||||
|
assert "This is me, the human" in messages_json
|
||||||
|
|
||||||
|
# remove the record from Elasticsearch, so the next test run won't pick it up
|
||||||
|
memory.chat_memory.clear()
|
||||||
|
|
||||||
|
assert memory.chat_memory.messages == []
|
Loading…
Reference in New Issue
Block a user