mirror of
https://github.com/hwchase17/langchain.git
synced 2025-06-01 20:49:17 +00:00
elasticsearch: add ElasticsearchRetriever
(#18587)
Implement [Retriever](https://python.langchain.com/docs/modules/data_connection/retrievers/) interface for Elasticsearch. I opted to only expose the `body`, which gives you full flexibility, and none of the other 68 arguments of the [search method](https://elasticsearch-py.readthedocs.io/en/v8.12.1/api/elasticsearch.html#elasticsearch.Elasticsearch.search). Added a user agent header for usage tracking in Elastic Cloud. --------- Co-authored-by: Erick Friis <erick@langchain.dev>
This commit is contained in:
parent
8bc347c5fc
commit
ee7a7954b9
@ -2,6 +2,8 @@ from enum import Enum
|
||||
from typing import List, Union
|
||||
|
||||
import numpy as np
|
||||
from elasticsearch import Elasticsearch
|
||||
from langchain_core import __version__ as langchain_version
|
||||
|
||||
Matrix = Union[List[List[float]], List[np.ndarray], np.ndarray]
|
||||
|
||||
@ -17,6 +19,12 @@ class DistanceStrategy(str, Enum):
|
||||
COSINE = "COSINE"
|
||||
|
||||
|
||||
def with_user_agent_header(client: Elasticsearch, header_prefix: str) -> Elasticsearch:
    """Return a client whose requests carry a '<prefix>/<langchain version>' user-agent.

    The original client is not mutated; ``options()`` produces a copy with the
    merged headers, which lets Elastic Cloud attribute usage to this integration.
    """
    merged = dict(client._headers)
    merged["user-agent"] = f"{header_prefix}/{langchain_version}"
    return client.options(headers=merged)
|
||||
|
||||
|
||||
def maximal_marginal_relevance(
|
||||
query_embedding: np.ndarray,
|
||||
embedding_list: list,
|
||||
|
@ -1,7 +1,7 @@
|
||||
import json
|
||||
import logging
|
||||
from time import time
|
||||
from typing import TYPE_CHECKING, Any, Dict, List, Optional
|
||||
from typing import TYPE_CHECKING, List, Optional
|
||||
|
||||
from langchain_core.chat_history import BaseChatMessageHistory
|
||||
from langchain_core.messages import (
|
||||
@ -10,6 +10,9 @@ from langchain_core.messages import (
|
||||
messages_from_dict,
|
||||
)
|
||||
|
||||
from langchain_elasticsearch._utilities import with_user_agent_header
|
||||
from langchain_elasticsearch.client import create_elasticsearch_client
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from elasticsearch import Elasticsearch
|
||||
|
||||
@ -51,23 +54,27 @@ class ElasticsearchChatMessageHistory(BaseChatMessageHistory):
|
||||
|
||||
# Initialize Elasticsearch client from passed client arg or connection info
|
||||
if es_connection is not None:
|
||||
self.client = es_connection.options(
|
||||
headers={"user-agent": self.get_user_agent()}
|
||||
)
|
||||
self.client = es_connection
|
||||
elif es_url is not None or es_cloud_id is not None:
|
||||
self.client = ElasticsearchChatMessageHistory.connect_to_elasticsearch(
|
||||
es_url=es_url,
|
||||
username=es_user,
|
||||
password=es_password,
|
||||
cloud_id=es_cloud_id,
|
||||
api_key=es_api_key,
|
||||
)
|
||||
try:
|
||||
self.client = create_elasticsearch_client(
|
||||
url=es_url,
|
||||
username=es_user,
|
||||
password=es_password,
|
||||
cloud_id=es_cloud_id,
|
||||
api_key=es_api_key,
|
||||
)
|
||||
except Exception as err:
|
||||
logger.error(f"Error connecting to Elasticsearch: {err}")
|
||||
raise err
|
||||
else:
|
||||
raise ValueError(
|
||||
"""Either provide a pre-existing Elasticsearch connection, \
|
||||
or valid credentials for creating a new connection."""
|
||||
)
|
||||
|
||||
self.client = with_user_agent_header(self.client, "langchain-py-ms")
|
||||
|
||||
if self.client.indices.exists(index=index):
|
||||
logger.debug(
|
||||
f"Chat history index {index} already exists, skipping creation."
|
||||
@ -86,60 +93,6 @@ class ElasticsearchChatMessageHistory(BaseChatMessageHistory):
|
||||
},
|
||||
)
|
||||
|
||||
@staticmethod
def get_user_agent() -> str:
    """Build the user-agent string identifying this chat-history integration."""
    # Imported lazily so module import does not require langchain_core at load time.
    from langchain_core import __version__

    return "langchain-py-ms/" + __version__
|
||||
|
||||
@staticmethod
def connect_to_elasticsearch(
    *,
    es_url: Optional[str] = None,
    cloud_id: Optional[str] = None,
    api_key: Optional[str] = None,
    username: Optional[str] = None,
    password: Optional[str] = None,
) -> "Elasticsearch":
    """Create and connection-test an Elasticsearch client.

    Exactly one of ``es_url`` (host URL) or ``cloud_id`` (Elastic Cloud
    deployment id) must be provided. Authentication is optional;
    ``api_key`` takes precedence over ``username``/``password`` basic auth.

    Raises:
        ImportError: if the ``elasticsearch`` package is not installed.
        ValueError: if both or neither of ``es_url``/``cloud_id`` are given.
        Exception: re-raised from ``es_client.info()`` when the connection fails.
    """
    try:
        import elasticsearch
    except ImportError:
        raise ImportError(
            "Could not import elasticsearch python package. "
            "Please install it with `pip install elasticsearch`."
        )

    # URL and cloud id are mutually exclusive ways to address the cluster.
    if es_url and cloud_id:
        raise ValueError(
            "Both es_url and cloud_id are defined. Please provide only one."
        )

    connection_params: Dict[str, Any] = {}

    if es_url:
        connection_params["hosts"] = [es_url]
    elif cloud_id:
        connection_params["cloud_id"] = cloud_id
    else:
        raise ValueError("Please provide either elasticsearch_url or cloud_id.")

    # api_key wins over basic auth when both are supplied.
    if api_key:
        connection_params["api_key"] = api_key
    elif username and password:
        connection_params["basic_auth"] = (username, password)

    # Tag every request with a user-agent so usage is attributable in Elastic Cloud.
    es_client = elasticsearch.Elasticsearch(
        **connection_params,
        headers={"user-agent": ElasticsearchChatMessageHistory.get_user_agent()},
    )
    try:
        # info() performs a round-trip, failing fast on bad credentials or URL.
        es_client.info()
    except Exception as err:
        logger.error(f"Error connecting to Elasticsearch: {err}")
        raise err

    return es_client
|
||||
|
||||
@property
|
||||
def messages(self) -> List[BaseMessage]: # type: ignore[override]
|
||||
"""Retrieve the messages from Elasticsearch"""
|
||||
|
@ -0,0 +1,40 @@
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
from elasticsearch import Elasticsearch
|
||||
|
||||
|
||||
def create_elasticsearch_client(
    url: Optional[str] = None,
    cloud_id: Optional[str] = None,
    api_key: Optional[str] = None,
    username: Optional[str] = None,
    password: Optional[str] = None,
    params: Optional[Dict[str, Any]] = None,
) -> Elasticsearch:
    """Create and connection-test an Elasticsearch client.

    Exactly one of ``url`` (host URL) or ``cloud_id`` (Elastic Cloud
    deployment id) must be provided. Authentication is optional;
    ``api_key`` takes precedence over ``username``/``password`` basic auth.

    Args:
        url: Elasticsearch host URL.
        cloud_id: Elastic Cloud deployment id (mutually exclusive with url).
        api_key: API key credential.
        username: Basic-auth user (used only together with password).
        password: Basic-auth password.
        params: Extra keyword arguments merged into the client constructor;
            these can override the values computed above.

    Raises:
        ValueError: if both or neither of ``url``/``cloud_id`` are given.
        Exception: re-raised from ``es_client.info()`` when the connection fails.
    """
    # Fix: error messages now name the actual parameters of this function
    # ("url"/"cloud_id"), not the "es_url"/"elasticsearch_url" names used by
    # other call sites, so callers can tell which argument to change.
    if url and cloud_id:
        raise ValueError(
            "Both url and cloud_id are defined. Please provide only one."
        )

    connection_params: Dict[str, Any] = {}

    if url:
        connection_params["hosts"] = [url]
    elif cloud_id:
        connection_params["cloud_id"] = cloud_id
    else:
        raise ValueError("Please provide either url or cloud_id.")

    # api_key wins over basic auth when both are supplied.
    if api_key:
        connection_params["api_key"] = api_key
    elif username and password:
        connection_params["basic_auth"] = (username, password)

    # Caller-supplied params are applied last so they can override anything.
    if params is not None:
        connection_params.update(params)

    es_client = Elasticsearch(**connection_params)

    es_client.info()  # test connection; raises on unreachable cluster/bad auth

    return es_client
|
@ -0,0 +1,97 @@
|
||||
import logging
|
||||
from typing import Any, Callable, Dict, List, Optional
|
||||
|
||||
from elasticsearch import Elasticsearch
|
||||
from langchain_core.callbacks import CallbackManagerForRetrieverRun
|
||||
from langchain_core.documents import Document
|
||||
from langchain_core.retrievers import BaseRetriever
|
||||
|
||||
from langchain_elasticsearch._utilities import with_user_agent_header
|
||||
from langchain_elasticsearch.client import create_elasticsearch_client
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class ElasticsearchRetriever(BaseRetriever):
    """
    Elasticsearch retriever

    Args:
        es_client: Elasticsearch client connection. Alternatively you can use the
            `from_es_params` method with parameters to initialize the client.
        index_name: The name of the index to query.
        body_func: Function to create an Elasticsearch DSL query body from a search
            string. All parameters (including for example the `size` parameter to limit
            the number of results) must also be set in the body.
        content_field: The document field name that contains the page content.
        document_mapper: Function to map Elasticsearch hits to LangChain Documents.
    """

    # Pydantic-declared fields (validated by BaseRetriever's model machinery).
    es_client: Elasticsearch
    index_name: str
    body_func: Callable[[str], Dict]
    # Exactly one of content_field / document_mapper must be set (enforced in __init__).
    content_field: Optional[str] = None
    document_mapper: Optional[Callable[[Dict], Document]] = None

    def __init__(self, **kwargs: Any) -> None:
        """Validate the content_field/document_mapper choice and tag the client.

        Raises:
            ValueError: if neither or both of content_field and document_mapper
                are provided.
        """
        super().__init__(**kwargs)

        if self.content_field is None and self.document_mapper is None:
            raise ValueError("One of content_field or document_mapper must be defined.")
        if self.content_field is not None and self.document_mapper is not None:
            raise ValueError(
                "Both content_field and document_mapper are defined. "
                "Please provide only one."
            )

        # Default mapper extracts page content from content_field.
        self.document_mapper = self.document_mapper or self._field_mapper
        # Wrap the client so requests carry a retriever-specific user-agent.
        self.es_client = with_user_agent_header(self.es_client, "langchain-py-r")

    @staticmethod
    def from_es_params(
        index_name: str,
        body_func: Callable[[str], Dict],
        content_field: Optional[str] = None,
        document_mapper: Optional[Callable[[Dict], Document]] = None,
        url: Optional[str] = None,
        cloud_id: Optional[str] = None,
        api_key: Optional[str] = None,
        username: Optional[str] = None,
        password: Optional[str] = None,
        params: Optional[Dict[str, Any]] = None,
    ) -> "ElasticsearchRetriever":
        """Alternate constructor: build the Elasticsearch client from connection
        parameters instead of accepting a pre-built ``es_client``.

        Raises:
            Exception: re-raised (after logging) when client creation or the
                connection test fails.
        """
        client = None
        try:
            client = create_elasticsearch_client(
                url=url,
                cloud_id=cloud_id,
                api_key=api_key,
                username=username,
                password=password,
                params=params,
            )
        except Exception as err:
            logger.error(f"Error connecting to Elasticsearch: {err}")
            raise err

        return ElasticsearchRetriever(
            es_client=client,
            index_name=index_name,
            body_func=body_func,
            content_field=content_field,
            document_mapper=document_mapper,
        )

    def _get_relevant_documents(
        self, query: str, *, run_manager: CallbackManagerForRetrieverRun
    ) -> List[Document]:
        """Run the body_func-generated query and map each hit to a Document."""
        # __init__ guarantees both are set; guard against direct field mutation.
        if not self.es_client or not self.document_mapper:
            raise ValueError("faulty configuration")  # should not happen

        body = self.body_func(query)
        results = self.es_client.search(index=self.index_name, body=body)
        return [self.document_mapper(hit) for hit in results["hits"]["hits"]]

    def _field_mapper(self, hit: Dict[str, Any]) -> Document:
        """Default hit->Document mapper: content_field becomes page_content,
        the remaining hit (including the rest of _source) becomes metadata.
        """
        # pop() removes the content from _source so it is not duplicated
        # in metadata; note this mutates the hit dict in place.
        content = hit["_source"].pop(self.content_field)
        return Document(page_content=content, metadata=hit)
|
@ -23,6 +23,7 @@ from langchain_core.vectorstores import VectorStore
|
||||
from langchain_elasticsearch._utilities import (
|
||||
DistanceStrategy,
|
||||
maximal_marginal_relevance,
|
||||
with_user_agent_header,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@ -526,9 +527,7 @@ class ElasticsearchStore(VectorStore):
|
||||
self.strategy = strategy
|
||||
|
||||
if es_connection is not None:
|
||||
headers = dict(es_connection._headers)
|
||||
headers.update({"user-agent": self.get_user_agent()})
|
||||
self.client = es_connection.options(headers=headers)
|
||||
self.client = es_connection
|
||||
elif es_url is not None or es_cloud_id is not None:
|
||||
self.client = ElasticsearchStore.connect_to_elasticsearch(
|
||||
es_url=es_url,
|
||||
@ -544,11 +543,7 @@ class ElasticsearchStore(VectorStore):
|
||||
or valid credentials for creating a new connection."""
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def get_user_agent() -> str:
|
||||
from langchain_core import __version__
|
||||
|
||||
return f"langchain-py-vs/{__version__}"
|
||||
self.client = with_user_agent_header(self.client, "langchain-py-vs")
|
||||
|
||||
@staticmethod
|
||||
def connect_to_elasticsearch(
|
||||
@ -582,10 +577,7 @@ class ElasticsearchStore(VectorStore):
|
||||
if es_params is not None:
|
||||
connection_params.update(es_params)
|
||||
|
||||
es_client = Elasticsearch(
|
||||
**connection_params,
|
||||
headers={"user-agent": ElasticsearchStore.get_user_agent()},
|
||||
)
|
||||
es_client = Elasticsearch(**connection_params)
|
||||
try:
|
||||
es_client.info()
|
||||
except Exception as e:
|
||||
|
133
libs/partners/elasticsearch/poetry.lock
generated
133
libs/partners/elasticsearch/poetry.lock
generated
@ -599,7 +599,7 @@ files = [
|
||||
|
||||
[[package]]
|
||||
name = "langchain"
|
||||
version = "0.1.10"
|
||||
version = "0.1.11"
|
||||
description = "Building applications with LLMs through composability"
|
||||
optional = false
|
||||
python-versions = ">=3.8.1,<4.0"
|
||||
@ -612,9 +612,9 @@ async-timeout = {version = "^4.0.0", markers = "python_version < \"3.11\""}
|
||||
dataclasses-json = ">= 0.5.7, < 0.7"
|
||||
jsonpatch = "^1.33"
|
||||
langchain-community = ">=0.0.25,<0.1"
|
||||
langchain-core = ">=0.1.28,<0.2"
|
||||
langchain-core = ">=0.1.29,<0.2"
|
||||
langchain-text-splitters = ">=0.0.1,<0.1"
|
||||
langsmith = "^0.1.14"
|
||||
langsmith = "^0.1.17"
|
||||
numpy = "^1"
|
||||
pydantic = ">=1,<3"
|
||||
PyYAML = ">=5.3"
|
||||
@ -671,7 +671,7 @@ url = "../../community"
|
||||
|
||||
[[package]]
|
||||
name = "langchain-core"
|
||||
version = "0.1.28"
|
||||
version = "0.1.29"
|
||||
description = "Building applications with LLMs through composability"
|
||||
optional = false
|
||||
python-versions = ">=3.8.1,<4.0"
|
||||
@ -716,13 +716,13 @@ url = "../../text-splitters"
|
||||
|
||||
[[package]]
|
||||
name = "langsmith"
|
||||
version = "0.1.14"
|
||||
version = "0.1.21"
|
||||
description = "Client library to connect to the LangSmith LLM Tracing and Evaluation Platform."
|
||||
optional = false
|
||||
python-versions = ">=3.8.1,<4.0"
|
||||
files = [
|
||||
{file = "langsmith-0.1.14-py3-none-any.whl", hash = "sha256:ecb243057d2a43c2da0524fe395585bc3421bb5d24f1cdd53eb06fbe63e43a69"},
|
||||
{file = "langsmith-0.1.14.tar.gz", hash = "sha256:b95f267d25681f4c9862bb68236fba8a57a60ec7921ecfdaa125936807e51bde"},
|
||||
{file = "langsmith-0.1.21-py3-none-any.whl", hash = "sha256:ac3d455d9651879ed306500a0504a2b9b9909225ab178e2446a8bace75e65e23"},
|
||||
{file = "langsmith-0.1.21.tar.gz", hash = "sha256:eef6b8a0d3bec7fcfc69ac5b35a16365ffac025dab0c1a4d77d6a7f7d3bbd3de"},
|
||||
]
|
||||
|
||||
[package.dependencies]
|
||||
@ -732,13 +732,13 @@ requests = ">=2,<3"
|
||||
|
||||
[[package]]
|
||||
name = "marshmallow"
|
||||
version = "3.21.0"
|
||||
version = "3.21.1"
|
||||
description = "A lightweight library for converting complex datatypes to and from native Python datatypes."
|
||||
optional = false
|
||||
python-versions = ">=3.8"
|
||||
files = [
|
||||
{file = "marshmallow-3.21.0-py3-none-any.whl", hash = "sha256:e7997f83571c7fd476042c2c188e4ee8a78900ca5e74bd9c8097afa56624e9bd"},
|
||||
{file = "marshmallow-3.21.0.tar.gz", hash = "sha256:20f53be28c6e374a711a16165fb22a8dc6003e3f7cda1285e3ca777b9193885b"},
|
||||
{file = "marshmallow-3.21.1-py3-none-any.whl", hash = "sha256:f085493f79efb0644f270a9bf2892843142d80d7174bbbd2f3713f2a589dc633"},
|
||||
{file = "marshmallow-3.21.1.tar.gz", hash = "sha256:4e65e9e0d80fc9e609574b9983cf32579f305c718afb30d7233ab818571768c3"},
|
||||
]
|
||||
|
||||
[package.dependencies]
|
||||
@ -1033,13 +1033,13 @@ testing = ["pytest", "pytest-benchmark"]
|
||||
|
||||
[[package]]
|
||||
name = "pydantic"
|
||||
version = "2.6.2"
|
||||
version = "2.6.3"
|
||||
description = "Data validation using Python type hints"
|
||||
optional = false
|
||||
python-versions = ">=3.8"
|
||||
files = [
|
||||
{file = "pydantic-2.6.2-py3-none-any.whl", hash = "sha256:37a5432e54b12fecaa1049c5195f3d860a10e01bdfd24f1840ef14bd0d3aeab3"},
|
||||
{file = "pydantic-2.6.2.tar.gz", hash = "sha256:a09be1c3d28f3abe37f8a78af58284b236a92ce520105ddc91a6d29ea1176ba7"},
|
||||
{file = "pydantic-2.6.3-py3-none-any.whl", hash = "sha256:72c6034df47f46ccdf81869fddb81aade68056003900a8724a4f160700016a2a"},
|
||||
{file = "pydantic-2.6.3.tar.gz", hash = "sha256:e07805c4c7f5c6826e33a1d4c9d47950d7eaf34868e2690f8594d2e30241f11f"},
|
||||
]
|
||||
|
||||
[package.dependencies]
|
||||
@ -1215,13 +1215,13 @@ watchdog = ">=2.0.0"
|
||||
|
||||
[[package]]
|
||||
name = "python-dateutil"
|
||||
version = "2.8.2"
|
||||
version = "2.9.0.post0"
|
||||
description = "Extensions to the standard Python datetime module"
|
||||
optional = false
|
||||
python-versions = "!=3.0.*,!=3.1.*,!=3.2.*,>=2.7"
|
||||
files = [
|
||||
{file = "python-dateutil-2.8.2.tar.gz", hash = "sha256:0123cacc1627ae19ddf3c27a5de5bd67ee4586fbdd6440d9748f8abb483d3e86"},
|
||||
{file = "python_dateutil-2.8.2-py2.py3-none-any.whl", hash = "sha256:961d03dc3453ebbc59dbdea9e4e11c5651520a876d0f4db161e8674aae935da9"},
|
||||
{file = "python-dateutil-2.9.0.post0.tar.gz", hash = "sha256:37dd54208da7e1cd875388217d5e00ebd4179249f90fb72437e91a35459a0ad3"},
|
||||
{file = "python_dateutil-2.9.0.post0-py2.py3-none-any.whl", hash = "sha256:a8b2bc7bffae282281c8140a97d3aa9c14da0b136dfe83f850eea9a5f7470427"},
|
||||
]
|
||||
|
||||
[package.dependencies]
|
||||
@ -1252,6 +1252,7 @@ files = [
|
||||
{file = "PyYAML-6.0.1-cp311-cp311-win_amd64.whl", hash = "sha256:bf07ee2fef7014951eeb99f56f39c9bb4af143d8aa3c21b1677805985307da34"},
|
||||
{file = "PyYAML-6.0.1-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:855fb52b0dc35af121542a76b9a84f8d1cd886ea97c84703eaa6d88e37a2ad28"},
|
||||
{file = "PyYAML-6.0.1-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:40df9b996c2b73138957fe23a16a4f0ba614f4c0efce1e9406a184b6d07fa3a9"},
|
||||
{file = "PyYAML-6.0.1-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a08c6f0fe150303c1c6b71ebcd7213c2858041a7e01975da3a99aed1e7a378ef"},
|
||||
{file = "PyYAML-6.0.1-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:6c22bec3fbe2524cde73d7ada88f6566758a8f7227bfbf93a408a9d86bcc12a0"},
|
||||
{file = "PyYAML-6.0.1-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:8d4e9c88387b0f5c7d5f281e55304de64cf7f9c0021a3525bd3b1c542da3b0e4"},
|
||||
{file = "PyYAML-6.0.1-cp312-cp312-win32.whl", hash = "sha256:d483d2cdf104e7c9fa60c544d92981f12ad66a457afae824d146093b8c294c54"},
|
||||
@ -1357,60 +1358,60 @@ files = [
|
||||
|
||||
[[package]]
|
||||
name = "sqlalchemy"
|
||||
version = "2.0.27"
|
||||
version = "2.0.28"
|
||||
description = "Database Abstraction Library"
|
||||
optional = false
|
||||
python-versions = ">=3.7"
|
||||
files = [
|
||||
{file = "SQLAlchemy-2.0.27-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:d04e579e911562f1055d26dab1868d3e0bb905db3bccf664ee8ad109f035618a"},
|
||||
{file = "SQLAlchemy-2.0.27-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:fa67d821c1fd268a5a87922ef4940442513b4e6c377553506b9db3b83beebbd8"},
|
||||
{file = "SQLAlchemy-2.0.27-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:6c7a596d0be71b7baa037f4ac10d5e057d276f65a9a611c46970f012752ebf2d"},
|
||||
{file = "SQLAlchemy-2.0.27-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:954d9735ee9c3fa74874c830d089a815b7b48df6f6b6e357a74130e478dbd951"},
|
||||
{file = "SQLAlchemy-2.0.27-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:5cd20f58c29bbf2680039ff9f569fa6d21453fbd2fa84dbdb4092f006424c2e6"},
|
||||
{file = "SQLAlchemy-2.0.27-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:03f448ffb731b48323bda68bcc93152f751436ad6037f18a42b7e16af9e91c07"},
|
||||
{file = "SQLAlchemy-2.0.27-cp310-cp310-win32.whl", hash = "sha256:d997c5938a08b5e172c30583ba6b8aad657ed9901fc24caf3a7152eeccb2f1b4"},
|
||||
{file = "SQLAlchemy-2.0.27-cp310-cp310-win_amd64.whl", hash = "sha256:eb15ef40b833f5b2f19eeae65d65e191f039e71790dd565c2af2a3783f72262f"},
|
||||
{file = "SQLAlchemy-2.0.27-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:6c5bad7c60a392850d2f0fee8f355953abaec878c483dd7c3836e0089f046bf6"},
|
||||
{file = "SQLAlchemy-2.0.27-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:a3012ab65ea42de1be81fff5fb28d6db893ef978950afc8130ba707179b4284a"},
|
||||
{file = "SQLAlchemy-2.0.27-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:dbcd77c4d94b23e0753c5ed8deba8c69f331d4fd83f68bfc9db58bc8983f49cd"},
|
||||
{file = "SQLAlchemy-2.0.27-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:d177b7e82f6dd5e1aebd24d9c3297c70ce09cd1d5d37b43e53f39514379c029c"},
|
||||
{file = "SQLAlchemy-2.0.27-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:680b9a36029b30cf063698755d277885d4a0eab70a2c7c6e71aab601323cba45"},
|
||||
{file = "SQLAlchemy-2.0.27-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:1306102f6d9e625cebaca3d4c9c8f10588735ef877f0360b5cdb4fdfd3fd7131"},
|
||||
{file = "SQLAlchemy-2.0.27-cp311-cp311-win32.whl", hash = "sha256:5b78aa9f4f68212248aaf8943d84c0ff0f74efc65a661c2fc68b82d498311fd5"},
|
||||
{file = "SQLAlchemy-2.0.27-cp311-cp311-win_amd64.whl", hash = "sha256:15e19a84b84528f52a68143439d0c7a3a69befcd4f50b8ef9b7b69d2628ae7c4"},
|
||||
{file = "SQLAlchemy-2.0.27-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:0de1263aac858f288a80b2071990f02082c51d88335a1db0d589237a3435fe71"},
|
||||
{file = "SQLAlchemy-2.0.27-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:ce850db091bf7d2a1f2fdb615220b968aeff3849007b1204bf6e3e50a57b3d32"},
|
||||
{file = "SQLAlchemy-2.0.27-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:8dfc936870507da96aebb43e664ae3a71a7b96278382bcfe84d277b88e379b18"},
|
||||
{file = "SQLAlchemy-2.0.27-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:c4fbe6a766301f2e8a4519f4500fe74ef0a8509a59e07a4085458f26228cd7cc"},
|
||||
{file = "SQLAlchemy-2.0.27-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:4535c49d961fe9a77392e3a630a626af5baa967172d42732b7a43496c8b28876"},
|
||||
{file = "SQLAlchemy-2.0.27-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:0fb3bffc0ced37e5aa4ac2416f56d6d858f46d4da70c09bb731a246e70bff4d5"},
|
||||
{file = "SQLAlchemy-2.0.27-cp312-cp312-win32.whl", hash = "sha256:7f470327d06400a0aa7926b375b8e8c3c31d335e0884f509fe272b3c700a7254"},
|
||||
{file = "SQLAlchemy-2.0.27-cp312-cp312-win_amd64.whl", hash = "sha256:f9374e270e2553653d710ece397df67db9d19c60d2647bcd35bfc616f1622dcd"},
|
||||
{file = "SQLAlchemy-2.0.27-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:e97cf143d74a7a5a0f143aa34039b4fecf11343eed66538610debc438685db4a"},
|
||||
{file = "SQLAlchemy-2.0.27-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:d7b5a3e2120982b8b6bd1d5d99e3025339f7fb8b8267551c679afb39e9c7c7f1"},
|
||||
{file = "SQLAlchemy-2.0.27-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:e36aa62b765cf9f43a003233a8c2d7ffdeb55bc62eaa0a0380475b228663a38f"},
|
||||
{file = "SQLAlchemy-2.0.27-cp37-cp37m-musllinux_1_1_aarch64.whl", hash = "sha256:5ada0438f5b74c3952d916c199367c29ee4d6858edff18eab783b3978d0db16d"},
|
||||
{file = "SQLAlchemy-2.0.27-cp37-cp37m-musllinux_1_1_x86_64.whl", hash = "sha256:b1d9d1bfd96eef3c3faedb73f486c89e44e64e40e5bfec304ee163de01cf996f"},
|
||||
{file = "SQLAlchemy-2.0.27-cp37-cp37m-win32.whl", hash = "sha256:ca891af9f3289d24a490a5fde664ea04fe2f4984cd97e26de7442a4251bd4b7c"},
|
||||
{file = "SQLAlchemy-2.0.27-cp37-cp37m-win_amd64.whl", hash = "sha256:fd8aafda7cdff03b905d4426b714601c0978725a19efc39f5f207b86d188ba01"},
|
||||
{file = "SQLAlchemy-2.0.27-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:ec1f5a328464daf7a1e4e385e4f5652dd9b1d12405075ccba1df842f7774b4fc"},
|
||||
{file = "SQLAlchemy-2.0.27-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:ad862295ad3f644e3c2c0d8b10a988e1600d3123ecb48702d2c0f26771f1c396"},
|
||||
{file = "SQLAlchemy-2.0.27-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:48217be1de7d29a5600b5c513f3f7664b21d32e596d69582be0a94e36b8309cb"},
|
||||
{file = "SQLAlchemy-2.0.27-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:9e56afce6431450442f3ab5973156289bd5ec33dd618941283847c9fd5ff06bf"},
|
||||
{file = "SQLAlchemy-2.0.27-cp38-cp38-musllinux_1_1_aarch64.whl", hash = "sha256:611068511b5531304137bcd7fe8117c985d1b828eb86043bd944cebb7fae3910"},
|
||||
{file = "SQLAlchemy-2.0.27-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:b86abba762ecfeea359112b2bb4490802b340850bbee1948f785141a5e020de8"},
|
||||
{file = "SQLAlchemy-2.0.27-cp38-cp38-win32.whl", hash = "sha256:30d81cc1192dc693d49d5671cd40cdec596b885b0ce3b72f323888ab1c3863d5"},
|
||||
{file = "SQLAlchemy-2.0.27-cp38-cp38-win_amd64.whl", hash = "sha256:120af1e49d614d2525ac247f6123841589b029c318b9afbfc9e2b70e22e1827d"},
|
||||
{file = "SQLAlchemy-2.0.27-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:d07ee7793f2aeb9b80ec8ceb96bc8cc08a2aec8a1b152da1955d64e4825fcbac"},
|
||||
{file = "SQLAlchemy-2.0.27-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:cb0845e934647232b6ff5150df37ceffd0b67b754b9fdbb095233deebcddbd4a"},
|
||||
{file = "SQLAlchemy-2.0.27-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:1fc19ae2e07a067663dd24fca55f8ed06a288384f0e6e3910420bf4b1270cc51"},
|
||||
{file = "SQLAlchemy-2.0.27-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:b90053be91973a6fb6020a6e44382c97739736a5a9d74e08cc29b196639eb979"},
|
||||
{file = "SQLAlchemy-2.0.27-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:2f5c9dfb0b9ab5e3a8a00249534bdd838d943ec4cfb9abe176a6c33408430230"},
|
||||
{file = "SQLAlchemy-2.0.27-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:33e8bde8fff203de50399b9039c4e14e42d4d227759155c21f8da4a47fc8053c"},
|
||||
{file = "SQLAlchemy-2.0.27-cp39-cp39-win32.whl", hash = "sha256:d873c21b356bfaf1589b89090a4011e6532582b3a8ea568a00e0c3aab09399dd"},
|
||||
{file = "SQLAlchemy-2.0.27-cp39-cp39-win_amd64.whl", hash = "sha256:ff2f1b7c963961d41403b650842dc2039175b906ab2093635d8319bef0b7d620"},
|
||||
{file = "SQLAlchemy-2.0.27-py3-none-any.whl", hash = "sha256:1ab4e0448018d01b142c916cc7119ca573803a4745cfe341b8f95657812700ac"},
|
||||
{file = "SQLAlchemy-2.0.27.tar.gz", hash = "sha256:86a6ed69a71fe6b88bf9331594fa390a2adda4a49b5c06f98e47bf0d392534f8"},
|
||||
{file = "SQLAlchemy-2.0.28-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:e0b148ab0438f72ad21cb004ce3bdaafd28465c4276af66df3b9ecd2037bf252"},
|
||||
{file = "SQLAlchemy-2.0.28-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:bbda76961eb8f27e6ad3c84d1dc56d5bc61ba8f02bd20fcf3450bd421c2fcc9c"},
|
||||
{file = "SQLAlchemy-2.0.28-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:feea693c452d85ea0015ebe3bb9cd15b6f49acc1a31c28b3c50f4db0f8fb1e71"},
|
||||
{file = "SQLAlchemy-2.0.28-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:5da98815f82dce0cb31fd1e873a0cb30934971d15b74e0d78cf21f9e1b05953f"},
|
||||
{file = "SQLAlchemy-2.0.28-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:4a5adf383c73f2d49ad15ff363a8748319ff84c371eed59ffd0127355d6ea1da"},
|
||||
{file = "SQLAlchemy-2.0.28-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:56856b871146bfead25fbcaed098269d90b744eea5cb32a952df00d542cdd368"},
|
||||
{file = "SQLAlchemy-2.0.28-cp310-cp310-win32.whl", hash = "sha256:943aa74a11f5806ab68278284a4ddd282d3fb348a0e96db9b42cb81bf731acdc"},
|
||||
{file = "SQLAlchemy-2.0.28-cp310-cp310-win_amd64.whl", hash = "sha256:c6c4da4843e0dabde41b8f2e8147438330924114f541949e6318358a56d1875a"},
|
||||
{file = "SQLAlchemy-2.0.28-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:46a3d4e7a472bfff2d28db838669fc437964e8af8df8ee1e4548e92710929adc"},
|
||||
{file = "SQLAlchemy-2.0.28-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:0d3dd67b5d69794cfe82862c002512683b3db038b99002171f624712fa71aeaa"},
|
||||
{file = "SQLAlchemy-2.0.28-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:c61e2e41656a673b777e2f0cbbe545323dbe0d32312f590b1bc09da1de6c2a02"},
|
||||
{file = "SQLAlchemy-2.0.28-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:0315d9125a38026227f559488fe7f7cee1bd2fbc19f9fd637739dc50bb6380b2"},
|
||||
{file = "SQLAlchemy-2.0.28-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:af8ce2d31679006e7b747d30a89cd3ac1ec304c3d4c20973f0f4ad58e2d1c4c9"},
|
||||
{file = "SQLAlchemy-2.0.28-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:81ba314a08c7ab701e621b7ad079c0c933c58cdef88593c59b90b996e8b58fa5"},
|
||||
{file = "SQLAlchemy-2.0.28-cp311-cp311-win32.whl", hash = "sha256:1ee8bd6d68578e517943f5ebff3afbd93fc65f7ef8f23becab9fa8fb315afb1d"},
|
||||
{file = "SQLAlchemy-2.0.28-cp311-cp311-win_amd64.whl", hash = "sha256:ad7acbe95bac70e4e687a4dc9ae3f7a2f467aa6597049eeb6d4a662ecd990bb6"},
|
||||
{file = "SQLAlchemy-2.0.28-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:d3499008ddec83127ab286c6f6ec82a34f39c9817f020f75eca96155f9765097"},
|
||||
{file = "SQLAlchemy-2.0.28-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:9b66fcd38659cab5d29e8de5409cdf91e9986817703e1078b2fdaad731ea66f5"},
|
||||
{file = "SQLAlchemy-2.0.28-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:bea30da1e76cb1acc5b72e204a920a3a7678d9d52f688f087dc08e54e2754c67"},
|
||||
{file = "SQLAlchemy-2.0.28-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:124202b4e0edea7f08a4db8c81cc7859012f90a0d14ba2bf07c099aff6e96462"},
|
||||
{file = "SQLAlchemy-2.0.28-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:e23b88c69497a6322b5796c0781400692eca1ae5532821b39ce81a48c395aae9"},
|
||||
{file = "SQLAlchemy-2.0.28-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:4b6303bfd78fb3221847723104d152e5972c22367ff66edf09120fcde5ddc2e2"},
|
||||
{file = "SQLAlchemy-2.0.28-cp312-cp312-win32.whl", hash = "sha256:a921002be69ac3ab2cf0c3017c4e6a3377f800f1fca7f254c13b5f1a2f10022c"},
|
||||
{file = "SQLAlchemy-2.0.28-cp312-cp312-win_amd64.whl", hash = "sha256:b4a2cf92995635b64876dc141af0ef089c6eea7e05898d8d8865e71a326c0385"},
|
||||
{file = "SQLAlchemy-2.0.28-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:8e91b5e341f8c7f1e5020db8e5602f3ed045a29f8e27f7f565e0bdee3338f2c7"},
|
||||
{file = "SQLAlchemy-2.0.28-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:45c7b78dfc7278329f27be02c44abc0d69fe235495bb8e16ec7ef1b1a17952db"},
|
||||
{file = "SQLAlchemy-2.0.28-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:3eba73ef2c30695cb7eabcdb33bb3d0b878595737479e152468f3ba97a9c22a4"},
|
||||
{file = "SQLAlchemy-2.0.28-cp37-cp37m-musllinux_1_1_aarch64.whl", hash = "sha256:5df5d1dafb8eee89384fb7a1f79128118bc0ba50ce0db27a40750f6f91aa99d5"},
|
||||
{file = "SQLAlchemy-2.0.28-cp37-cp37m-musllinux_1_1_x86_64.whl", hash = "sha256:2858bbab1681ee5406650202950dc8f00e83b06a198741b7c656e63818633526"},
|
||||
{file = "SQLAlchemy-2.0.28-cp37-cp37m-win32.whl", hash = "sha256:9461802f2e965de5cff80c5a13bc945abea7edaa1d29360b485c3d2b56cdb075"},
|
||||
{file = "SQLAlchemy-2.0.28-cp37-cp37m-win_amd64.whl", hash = "sha256:a6bec1c010a6d65b3ed88c863d56b9ea5eeefdf62b5e39cafd08c65f5ce5198b"},
|
||||
{file = "SQLAlchemy-2.0.28-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:843a882cadebecc655a68bd9a5b8aa39b3c52f4a9a5572a3036fb1bb2ccdc197"},
|
||||
{file = "SQLAlchemy-2.0.28-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:dbb990612c36163c6072723523d2be7c3eb1517bbdd63fe50449f56afafd1133"},
|
||||
{file = "SQLAlchemy-2.0.28-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:bd7e4baf9161d076b9a7e432fce06217b9bd90cfb8f1d543d6e8c4595627edb9"},
|
||||
{file = "SQLAlchemy-2.0.28-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:e0a5354cb4de9b64bccb6ea33162cb83e03dbefa0d892db88a672f5aad638a75"},
|
||||
{file = "SQLAlchemy-2.0.28-cp38-cp38-musllinux_1_1_aarch64.whl", hash = "sha256:fffcc8edc508801ed2e6a4e7b0d150a62196fd28b4e16ab9f65192e8186102b6"},
|
||||
{file = "SQLAlchemy-2.0.28-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:aca7b6d99a4541b2ebab4494f6c8c2f947e0df4ac859ced575238e1d6ca5716b"},
|
||||
{file = "SQLAlchemy-2.0.28-cp38-cp38-win32.whl", hash = "sha256:8c7f10720fc34d14abad5b647bc8202202f4948498927d9f1b4df0fb1cf391b7"},
|
||||
{file = "SQLAlchemy-2.0.28-cp38-cp38-win_amd64.whl", hash = "sha256:243feb6882b06a2af68ecf4bec8813d99452a1b62ba2be917ce6283852cf701b"},
|
||||
{file = "SQLAlchemy-2.0.28-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:fc4974d3684f28b61b9a90fcb4c41fb340fd4b6a50c04365704a4da5a9603b05"},
|
||||
{file = "SQLAlchemy-2.0.28-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:87724e7ed2a936fdda2c05dbd99d395c91ea3c96f029a033a4a20e008dd876bf"},
|
||||
{file = "SQLAlchemy-2.0.28-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:68722e6a550f5de2e3cfe9da6afb9a7dd15ef7032afa5651b0f0c6b3adb8815d"},
|
||||
{file = "SQLAlchemy-2.0.28-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:328529f7c7f90adcd65aed06a161851f83f475c2f664a898af574893f55d9e53"},
|
||||
{file = "SQLAlchemy-2.0.28-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:df40c16a7e8be7413b885c9bf900d402918cc848be08a59b022478804ea076b8"},
|
||||
{file = "SQLAlchemy-2.0.28-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:426f2fa71331a64f5132369ede5171c52fd1df1bd9727ce621f38b5b24f48750"},
|
||||
{file = "SQLAlchemy-2.0.28-cp39-cp39-win32.whl", hash = "sha256:33157920b233bc542ce497a81a2e1452e685a11834c5763933b440fedd1d8e2d"},
|
||||
{file = "SQLAlchemy-2.0.28-cp39-cp39-win_amd64.whl", hash = "sha256:2f60843068e432311c886c5f03c4664acaef507cf716f6c60d5fde7265be9d7b"},
|
||||
{file = "SQLAlchemy-2.0.28-py3-none-any.whl", hash = "sha256:78bb7e8da0183a8301352d569900d9d3594c48ac21dc1c2ec6b3121ed8b6c986"},
|
||||
{file = "SQLAlchemy-2.0.28.tar.gz", hash = "sha256:dd53b6c4e6d960600fd6532b79ee28e2da489322fcf6648738134587faf767b6"},
|
||||
]
|
||||
|
||||
[package.dependencies]
|
||||
|
@ -0,0 +1,42 @@
|
||||
import os
|
||||
from typing import Any, Dict, List
|
||||
|
||||
from elastic_transport import Transport
|
||||
from elasticsearch import Elasticsearch
|
||||
|
||||
|
||||
def clear_test_indices(es: Elasticsearch) -> None:
    """Delete every index whose name starts with ``test_``, then refresh.

    Used by integration-test teardown so test runs do not leak indices.
    """
    for name in es.indices.get(index="_all").keys():
        if name.startswith("test_"):
            es.indices.delete(index=name)
    # Make the deletions visible to subsequent searches.
    es.indices.refresh(index="_all")
|
||||
|
||||
|
||||
def requests_saving_es_client() -> Elasticsearch:
    """Build an Elasticsearch client whose transport records every request.

    The returned client's transport keeps each request's kwargs in a
    ``requests`` list so tests can inspect headers and bodies after the fact.
    Connects to Elastic Cloud when ``ES_CLOUD_ID`` is set (with ``ES_API_KEY``),
    otherwise to ``ES_URL`` (default ``http://localhost:9200``).
    """

    class RequestSavingTransport(Transport):
        def __init__(self, *args: Any, **kwargs: Any) -> None:
            super().__init__(*args, **kwargs)
            # One dict of kwargs per performed request, in call order.
            self.requests: List[Dict] = []

        def perform_request(self, *args, **kwargs):  # type: ignore
            self.requests.append(kwargs)
            return super().perform_request(*args, **kwargs)

    cloud_id = os.environ.get("ES_CLOUD_ID")

    if cloud_id:
        # Running this integration test with Elastic Cloud
        # Required for in-stack inference testing (ELSER + model_id)
        return Elasticsearch(
            cloud_id=cloud_id,
            api_key=os.environ.get("ES_API_KEY"),
            transport_class=RequestSavingTransport,
        )

    # Running this integration test with local docker instance
    return Elasticsearch(
        hosts=[os.environ.get("ES_URL", "http://localhost:9200")],
        transport_class=RequestSavingTransport,
    )
|
@ -0,0 +1,35 @@
|
||||
version: "3"

services:
  elasticsearch:
    image: docker.elastic.co/elasticsearch/elasticsearch:8.12.1 # https://www.docker.elastic.co/r/elasticsearch/elasticsearch
    environment:
      - discovery.type=single-node
      - xpack.security.enabled=false # security has been disabled, so no login or password is required.
      - xpack.security.http.ssl.enabled=false
      - xpack.license.self_generated.type=trial
    ports:
      - "9200:9200"
    healthcheck:
      test:
        [
          "CMD-SHELL",
          "curl --silent --fail http://localhost:9200/_cluster/health || exit 1"
        ]
      interval: 10s
      retries: 60

  kibana:
    image: docker.elastic.co/kibana/kibana:8.12.1
    environment:
      - ELASTICSEARCH_URL=http://elasticsearch:9200
    ports:
      - "5601:5601"
    healthcheck:
      test:
        [
          "CMD-SHELL",
          "curl --silent --fail http://localhost:5601/login || exit 1"
        ]
      interval: 10s
      retries: 60
|
@ -10,8 +10,8 @@ from langchain_core.messages import message_to_dict
|
||||
from langchain_elasticsearch.chat_history import ElasticsearchChatMessageHistory
|
||||
|
||||
"""
|
||||
cd tests/integration_tests/memory/docker-compose
|
||||
docker-compose -f elasticsearch.yml up
|
||||
cd tests/integration_tests
|
||||
docker-compose up elasticsearch
|
||||
|
||||
By default runs against local docker instance of Elasticsearch.
|
||||
To run against Elastic Cloud, set the following environment variables:
|
||||
|
@ -0,0 +1,169 @@
|
||||
"""Test ElasticsearchRetriever functionality."""
|
||||
|
||||
import re
|
||||
import uuid
|
||||
from typing import Any, Dict
|
||||
|
||||
import pytest
|
||||
from elasticsearch import Elasticsearch
|
||||
from langchain_core.documents import Document
|
||||
|
||||
from langchain_elasticsearch.retrievers import ElasticsearchRetriever
|
||||
|
||||
from ._test_utilities import requests_saving_es_client
|
||||
|
||||
"""
|
||||
cd tests/integration_tests
|
||||
docker-compose up elasticsearch
|
||||
|
||||
By default runs against local docker instance of Elasticsearch.
|
||||
To run against Elastic Cloud, set the following environment variables:
|
||||
- ES_CLOUD_ID
|
||||
- ES_API_KEY
|
||||
"""
|
||||
|
||||
|
||||
def index_test_data(es_client: Elasticsearch, index_name: str, field_name: str) -> None:
    """Index five small fixture documents into *index_name*.

    Each document stores its text under *field_name* plus a constant
    ``another_field``; ``refresh=True`` makes every write searchable
    immediately, so tests need no explicit refresh step.
    """
    fixtures = [(1, "foo bar"), (2, "bar"), (3, "foo"), (4, "baz"), (5, "foo baz")]
    for doc_id, text in fixtures:
        es_client.index(
            index=index_name,
            id=str(doc_id),
            document={field_name: text, "another_field": 1},
            refresh=True,
        )
|
||||
|
||||
|
||||
class TestElasticsearchRetriever:
    """Integration tests for ElasticsearchRetriever against a live cluster."""

    @pytest.fixture(scope="function")
    def es_client(self) -> Any:
        # Client whose transport records each request for later inspection.
        return requests_saving_es_client()

    @pytest.fixture(scope="function")
    def index_name(self) -> str:
        """Return a unique index name so tests do not interfere."""
        return f"test_{uuid.uuid4().hex}"

    def test_user_agent_header(self, es_client: Elasticsearch, index_name: str) -> None:
        """Test that the user agent header is set correctly."""

        retriever = ElasticsearchRetriever(
            index_name=index_name,
            body_func=lambda _: {"query": {"match_all": {}}},
            content_field="text",
            es_client=es_client,
        )

        pattern = r"^langchain-py-r/\d+\.\d+\.\d+$"

        # The retriever's client must already carry the custom user agent.
        assert retriever.es_client
        user_agent = retriever.es_client._headers["User-Agent"]
        assert (
            re.match(pattern, user_agent) is not None
        ), f"The string '{user_agent}' does not match the expected pattern."

        index_test_data(es_client, index_name, "text")
        retriever.get_relevant_documents("foo")

        # The header must also appear on the actual search request sent.
        search_request = es_client.transport.requests[-1]  # type: ignore[attr-defined]
        user_agent = search_request["headers"]["User-Agent"]
        assert (
            re.match(pattern, user_agent) is not None
        ), f"The string '{user_agent}' does not match the expected pattern."

    def test_init_url(self, index_name: str) -> None:
        """Test end-to-end indexing and search."""

        text_field = "text"

        def body_func(query: str) -> Dict:
            return {"query": {"match": {text_field: {"query": query}}}}

        retriever = ElasticsearchRetriever.from_es_params(
            url="http://localhost:9200",
            index_name=index_name,
            body_func=body_func,
            content_field=text_field,
        )

        index_test_data(retriever.es_client, index_name, text_field)
        docs = retriever.get_relevant_documents("foo")

        assert {doc.page_content for doc in docs} == {"foo", "foo bar", "foo baz"}
        assert {doc.metadata["_id"] for doc in docs} == {"3", "1", "5"}
        for doc in docs:
            # The content field is lifted into page_content, not left in _source.
            assert set(doc.metadata.keys()) == {"_index", "_id", "_score", "_source"}
            assert text_field not in doc.metadata["_source"]
            assert "another_field" in doc.metadata["_source"]

    def test_init_client(self, es_client: Elasticsearch, index_name: str) -> None:
        """Test end-to-end indexing and search."""

        text_field = "text"

        def body_func(query: str) -> Dict:
            return {"query": {"match": {text_field: {"query": query}}}}

        retriever = ElasticsearchRetriever(
            index_name=index_name,
            body_func=body_func,
            content_field=text_field,
            es_client=es_client,
        )

        index_test_data(es_client, index_name, text_field)
        docs = retriever.get_relevant_documents("foo")

        assert {doc.page_content for doc in docs} == {"foo", "foo bar", "foo baz"}
        assert {doc.metadata["_id"] for doc in docs} == {"3", "1", "5"}
        for doc in docs:
            assert set(doc.metadata.keys()) == {"_index", "_id", "_score", "_source"}
            assert text_field not in doc.metadata["_source"]
            assert "another_field" in doc.metadata["_source"]

    def test_custom_mapper(self, es_client: Elasticsearch, index_name: str) -> None:
        """Test custom document mapper"""

        text_field = "text"
        meta = {"some_field": 12}

        def body_func(query: str) -> Dict:
            return {"query": {"match": {text_field: {"query": query}}}}

        def id_as_content(hit: Dict) -> Document:
            # Map each hit to a Document whose content is just the ES _id.
            return Document(page_content=hit["_id"], metadata=meta)

        retriever = ElasticsearchRetriever(
            index_name=index_name,
            body_func=body_func,
            document_mapper=id_as_content,
            es_client=es_client,
        )

        index_test_data(es_client, index_name, text_field)
        docs = retriever.get_relevant_documents("foo")

        assert [doc.page_content for doc in docs] == ["3", "1", "5"]
        assert [doc.metadata for doc in docs] == [meta, meta, meta]

    def test_fail_content_field_and_mapper(self, es_client: Elasticsearch) -> None:
        """Raise exception if both content_field and document_mapper are specified."""

        with pytest.raises(ValueError):
            ElasticsearchRetriever(
                content_field="text",
                document_mapper=lambda x: x,
                index_name="foo",
                body_func=lambda x: x,
                es_client=es_client,
            )

    def test_fail_neither_content_field_nor_mapper(
        self, es_client: Elasticsearch
    ) -> None:
        """Raise exception if neither content_field nor document_mapper are specified"""

        with pytest.raises(ValueError):
            ElasticsearchRetriever(
                index_name="foo",
                body_func=lambda x: x,
                es_client=es_client,
            )
|
@ -7,7 +7,6 @@ import uuid
|
||||
from typing import Any, Dict, Generator, List, Union
|
||||
|
||||
import pytest
|
||||
from elastic_transport import Transport
|
||||
from elasticsearch import Elasticsearch
|
||||
from elasticsearch.helpers import BulkIndexError
|
||||
from langchain_core.documents import Document
|
||||
@ -18,12 +17,13 @@ from ..fake_embeddings import (
|
||||
ConsistentFakeEmbeddings,
|
||||
FakeEmbeddings,
|
||||
)
|
||||
from ._test_utilities import clear_test_indices, requests_saving_es_client
|
||||
|
||||
logging.basicConfig(level=logging.DEBUG)
|
||||
|
||||
"""
|
||||
cd tests/integration_tests/vectorstores/docker-compose
|
||||
docker-compose -f elasticsearch.yml up
|
||||
cd tests/integration_tests
|
||||
docker-compose up elasticsearch
|
||||
|
||||
By default runs against local docker instance of Elasticsearch.
|
||||
To run against Elastic Cloud, set the following environment variables:
|
||||
@ -74,12 +74,8 @@ class TestElasticsearch:
|
||||
es = Elasticsearch(hosts=es_url)
|
||||
yield {"es_url": es_url}
|
||||
|
||||
# Clear all indexes
|
||||
index_names = es.indices.get(index="_all").keys()
|
||||
for index_name in index_names:
|
||||
if index_name.startswith("test_"):
|
||||
es.indices.delete(index=index_name)
|
||||
es.indices.refresh(index="_all")
|
||||
# clear indices
|
||||
clear_test_indices(es)
|
||||
|
||||
# clear all test pipelines
|
||||
try:
|
||||
@ -94,32 +90,11 @@ class TestElasticsearch:
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
return None
|
||||
|
||||
@pytest.fixture(scope="function")
|
||||
def es_client(self) -> Any:
|
||||
class CustomTransport(Transport):
|
||||
requests = []
|
||||
|
||||
def perform_request(self, *args, **kwargs): # type: ignore
|
||||
self.requests.append(kwargs)
|
||||
return super().perform_request(*args, **kwargs)
|
||||
|
||||
es_url = os.environ.get("ES_URL", "http://localhost:9200")
|
||||
cloud_id = os.environ.get("ES_CLOUD_ID")
|
||||
api_key = os.environ.get("ES_API_KEY")
|
||||
|
||||
if cloud_id:
|
||||
# Running this integration test with Elastic Cloud
|
||||
# Required for in-stack inference testing (ELSER + model_id)
|
||||
es = Elasticsearch(
|
||||
cloud_id=cloud_id,
|
||||
api_key=api_key,
|
||||
transport_class=CustomTransport,
|
||||
)
|
||||
return es
|
||||
else:
|
||||
# Running this integration test with local docker instance
|
||||
es = Elasticsearch(hosts=es_url, transport_class=CustomTransport)
|
||||
return es
|
||||
return requests_saving_es_client()
|
||||
|
||||
@pytest.fixture(scope="function")
|
||||
def index_name(self) -> str:
|
||||
@ -887,11 +862,8 @@ class TestElasticsearch:
|
||||
)
|
||||
|
||||
user_agent = es_client.transport.requests[0]["headers"]["User-Agent"]
|
||||
pattern = r"^langchain-py-vs/\d+\.\d+\.\d+$"
|
||||
match = re.match(pattern, user_agent)
|
||||
|
||||
assert (
|
||||
match is not None
|
||||
re.match(r"^langchain-py-vs/\d+\.\d+\.\d+$", user_agent) is not None
|
||||
), f"The string '{user_agent}' does not match the expected pattern."
|
||||
|
||||
def test_elasticsearch_with_internal_user_agent(
|
||||
@ -908,15 +880,12 @@ class TestElasticsearch:
|
||||
)
|
||||
|
||||
user_agent = store.client._headers["User-Agent"]
|
||||
pattern = r"^langchain-py-vs/\d+\.\d+\.\d+$"
|
||||
match = re.match(pattern, user_agent)
|
||||
|
||||
assert (
|
||||
match is not None
|
||||
re.match(r"^langchain-py-vs/\d+\.\d+\.\d+$", user_agent) is not None
|
||||
), f"The string '{user_agent}' does not match the expected pattern."
|
||||
|
||||
def test_bulk_args(self, es_client: Any, index_name: str) -> None:
|
||||
"""Test to make sure the user-agent is set correctly."""
|
||||
"""Test to make sure the bulk arguments work as expected."""
|
||||
|
||||
texts = ["foo", "bob", "baz"]
|
||||
ElasticsearchStore.from_texts(
|
||||
|
Loading…
Reference in New Issue
Block a user