From ee7a7954b9f4f2853da32080e20aff2fe26b1795 Mon Sep 17 00:00:00 2001 From: Max Jakob Date: Wed, 6 Mar 2024 01:42:50 +0100 Subject: [PATCH] elasticsearch: add `ElasticsearchRetriever` (#18587) Implement [Retriever](https://python.langchain.com/docs/modules/data_connection/retrievers/) interface for Elasticsearch. I opted to only expose the `body`, which gives you full flexibility, and none the other 68 arguments of the [search method](https://elasticsearch-py.readthedocs.io/en/v8.12.1/api/elasticsearch.html#elasticsearch.Elasticsearch.search). Added a user agent header for usage tracking in Elastic Cloud. --------- Co-authored-by: Erick Friis --- .../langchain_elasticsearch/_utilities.py | 8 + .../langchain_elasticsearch/chat_history.py | 83 ++------- .../langchain_elasticsearch/client.py | 40 +++++ .../langchain_elasticsearch/retrievers.py | 97 ++++++++++ .../langchain_elasticsearch/vectorstores.py | 16 +- libs/partners/elasticsearch/poetry.lock | 133 +++++++------- .../integration_tests/_test_utilities.py | 42 +++++ .../integration_tests/docker-compose.yml | 35 ++++ .../integration_tests/test_chat_history.py | 4 +- .../integration_tests/test_retrievers.py | 169 ++++++++++++++++++ .../integration_tests/test_vectorstores.py | 53 ++---- 11 files changed, 493 insertions(+), 187 deletions(-) create mode 100644 libs/partners/elasticsearch/langchain_elasticsearch/client.py create mode 100644 libs/partners/elasticsearch/langchain_elasticsearch/retrievers.py create mode 100644 libs/partners/elasticsearch/tests/integration_tests/_test_utilities.py create mode 100644 libs/partners/elasticsearch/tests/integration_tests/docker-compose.yml create mode 100644 libs/partners/elasticsearch/tests/integration_tests/test_retrievers.py diff --git a/libs/partners/elasticsearch/langchain_elasticsearch/_utilities.py b/libs/partners/elasticsearch/langchain_elasticsearch/_utilities.py index 237d587bc03..101cff425dc 100644 --- a/libs/partners/elasticsearch/langchain_elasticsearch/_utilities.py +++ b/libs/partners/elasticsearch/langchain_elasticsearch/_utilities.py @@ -2,6 +2,8 @@ from enum import Enum from typing import List, Union import numpy as np +from elasticsearch import Elasticsearch +from langchain_core import __version__ as langchain_version Matrix = Union[List[List[float]], List[np.ndarray], np.ndarray] @@ -17,6 +19,12 @@ class DistanceStrategy(str, Enum): COSINE = "COSINE" +def with_user_agent_header(client: Elasticsearch, header_prefix: str) -> Elasticsearch: + headers = dict(client._headers) + headers.update({"user-agent": f"{header_prefix}/{langchain_version}"}) + return client.options(headers=headers) + + def maximal_marginal_relevance( query_embedding: np.ndarray, embedding_list: list, diff --git a/libs/partners/elasticsearch/langchain_elasticsearch/chat_history.py b/libs/partners/elasticsearch/langchain_elasticsearch/chat_history.py index 026557c33fb..5d75beae503 100644 --- a/libs/partners/elasticsearch/langchain_elasticsearch/chat_history.py +++ b/libs/partners/elasticsearch/langchain_elasticsearch/chat_history.py @@ -1,7 +1,7 @@ import json import logging from time import time -from typing import TYPE_CHECKING, Any, Dict, List, Optional +from typing import TYPE_CHECKING, List, Optional from langchain_core.chat_history import BaseChatMessageHistory from langchain_core.messages import ( @@ -10,6 +10,9 @@ from langchain_core.messages import ( messages_from_dict, ) +from langchain_elasticsearch._utilities import with_user_agent_header +from langchain_elasticsearch.client import create_elasticsearch_client + if TYPE_CHECKING: from elasticsearch import Elasticsearch @@ -51,23 +54,27 @@ class ElasticsearchChatMessageHistory(BaseChatMessageHistory): # Initialize Elasticsearch client from passed client arg or connection info if es_connection is not None: - self.client = es_connection.options( - headers={"user-agent": self.get_user_agent()} - ) + self.client = es_connection elif es_url is not None or es_cloud_id is not None: - self.client = ElasticsearchChatMessageHistory.connect_to_elasticsearch( - es_url=es_url, - username=es_user, - password=es_password, - cloud_id=es_cloud_id, - api_key=es_api_key, - ) + try: + self.client = create_elasticsearch_client( + url=es_url, + username=es_user, + password=es_password, + cloud_id=es_cloud_id, + api_key=es_api_key, + ) + except Exception as err: + logger.error(f"Error connecting to Elasticsearch: {err}") + raise err else: raise ValueError( """Either provide a pre-existing Elasticsearch connection, \ or valid credentials for creating a new connection.""" ) + self.client = with_user_agent_header(self.client, "langchain-py-ms") + if self.client.indices.exists(index=index): logger.debug( f"Chat history index {index} already exists, skipping creation." @@ -86,60 +93,6 @@ class ElasticsearchChatMessageHistory(BaseChatMessageHistory): }, ) - @staticmethod - def get_user_agent() -> str: - from langchain_core import __version__ - - return f"langchain-py-ms/{__version__}" - - @staticmethod - def connect_to_elasticsearch( - *, - es_url: Optional[str] = None, - cloud_id: Optional[str] = None, - api_key: Optional[str] = None, - username: Optional[str] = None, - password: Optional[str] = None, - ) -> "Elasticsearch": - try: - import elasticsearch - except ImportError: - raise ImportError( - "Could not import elasticsearch python package. " - "Please install it with `pip install elasticsearch`." - ) - - if es_url and cloud_id: - raise ValueError( - "Both es_url and cloud_id are defined. Please provide only one." - ) - - connection_params: Dict[str, Any] = {} - - if es_url: - connection_params["hosts"] = [es_url] - elif cloud_id: - connection_params["cloud_id"] = cloud_id - else: - raise ValueError("Please provide either elasticsearch_url or cloud_id.") - - if api_key: - connection_params["api_key"] = api_key - elif username and password: - connection_params["basic_auth"] = (username, password) - - es_client = elasticsearch.Elasticsearch( - **connection_params, - headers={"user-agent": ElasticsearchChatMessageHistory.get_user_agent()}, - ) - try: - es_client.info() - except Exception as err: - logger.error(f"Error connecting to Elasticsearch: {err}") - raise err - - return es_client - @property def messages(self) -> List[BaseMessage]: # type: ignore[override] """Retrieve the messages from Elasticsearch""" diff --git a/libs/partners/elasticsearch/langchain_elasticsearch/client.py b/libs/partners/elasticsearch/langchain_elasticsearch/client.py new file mode 100644 index 00000000000..3e4b5460819 --- /dev/null +++ b/libs/partners/elasticsearch/langchain_elasticsearch/client.py @@ -0,0 +1,40 @@ +from typing import Any, Dict, Optional + +from elasticsearch import Elasticsearch + + +def create_elasticsearch_client( + url: Optional[str] = None, + cloud_id: Optional[str] = None, + api_key: Optional[str] = None, + username: Optional[str] = None, + password: Optional[str] = None, + params: Optional[Dict[str, Any]] = None, +) -> Elasticsearch: + if url and cloud_id: + raise ValueError( + "Both es_url and cloud_id are defined. Please provide only one." + ) + + connection_params: Dict[str, Any] = {} + + if url: + connection_params["hosts"] = [url] + elif cloud_id: + connection_params["cloud_id"] = cloud_id + else: + raise ValueError("Please provide either elasticsearch_url or cloud_id.") + + if api_key: + connection_params["api_key"] = api_key + elif username and password: + connection_params["basic_auth"] = (username, password) + + if params is not None: + connection_params.update(params) + + es_client = Elasticsearch(**connection_params) + + es_client.info() # test connection + + return es_client diff --git a/libs/partners/elasticsearch/langchain_elasticsearch/retrievers.py b/libs/partners/elasticsearch/langchain_elasticsearch/retrievers.py new file mode 100644 index 00000000000..e2375a83c90 --- /dev/null +++ b/libs/partners/elasticsearch/langchain_elasticsearch/retrievers.py @@ -0,0 +1,97 @@ +import logging +from typing import Any, Callable, Dict, List, Optional + +from elasticsearch import Elasticsearch +from langchain_core.callbacks import CallbackManagerForRetrieverRun +from langchain_core.documents import Document +from langchain_core.retrievers import BaseRetriever + +from langchain_elasticsearch._utilities import with_user_agent_header +from langchain_elasticsearch.client import create_elasticsearch_client + +logger = logging.getLogger(__name__) + + +class ElasticsearchRetriever(BaseRetriever): + """ + Elasticsearch retriever + + Args: + es_client: Elasticsearch client connection. Alternatively you can use the + `from_es_params` method with parameters to initialize the client. + index_name: The name of the index to query. + body_func: Function to create an Elasticsearch DSL query body from a search + string. All parameters (including for example the `size` parameter to limit + the number of results) must also be set in the body. + content_field: The document field name that contains the page content. + document_mapper: Function to map Elasticsearch hits to LangChain Documents. + """ + + es_client: Elasticsearch + index_name: str + body_func: Callable[[str], Dict] + content_field: Optional[str] = None + document_mapper: Optional[Callable[[Dict], Document]] = None + + def __init__(self, **kwargs: Any) -> None: + super().__init__(**kwargs) + + if self.content_field is None and self.document_mapper is None: + raise ValueError("One of content_field or document_mapper must be defined.") + if self.content_field is not None and self.document_mapper is not None: + raise ValueError( + "Both content_field and document_mapper are defined. " + "Please provide only one." + ) + + self.document_mapper = self.document_mapper or self._field_mapper + self.es_client = with_user_agent_header(self.es_client, "langchain-py-r") + + @staticmethod + def from_es_params( + index_name: str, + body_func: Callable[[str], Dict], + content_field: Optional[str] = None, + document_mapper: Optional[Callable[[Dict], Document]] = None, + url: Optional[str] = None, + cloud_id: Optional[str] = None, + api_key: Optional[str] = None, + username: Optional[str] = None, + password: Optional[str] = None, + params: Optional[Dict[str, Any]] = None, + ) -> "ElasticsearchRetriever": + client = None + try: + client = create_elasticsearch_client( + url=url, + cloud_id=cloud_id, + api_key=api_key, + username=username, + password=password, + params=params, + ) + except Exception as err: + logger.error(f"Error connecting to Elasticsearch: {err}") + raise err + + return ElasticsearchRetriever( + es_client=client, + index_name=index_name, + body_func=body_func, + content_field=content_field, + document_mapper=document_mapper, + ) + + def _get_relevant_documents( + self, query: str, *, run_manager: CallbackManagerForRetrieverRun + ) -> List[Document]: + if not self.es_client or not self.document_mapper: + raise ValueError("faulty configuration") # should not happen + + body = self.body_func(query) + results = self.es_client.search(index=self.index_name, body=body) + return [self.document_mapper(hit) for hit in results["hits"]["hits"]] + + def _field_mapper(self, hit: Dict[str, Any]) -> Document: + content = hit["_source"].pop(self.content_field) + return Document(page_content=content, metadata=hit) diff --git a/libs/partners/elasticsearch/langchain_elasticsearch/vectorstores.py b/libs/partners/elasticsearch/langchain_elasticsearch/vectorstores.py index 3d2522b2d16..a6fb2aacdc1 100644 --- a/libs/partners/elasticsearch/langchain_elasticsearch/vectorstores.py +++ b/libs/partners/elasticsearch/langchain_elasticsearch/vectorstores.py @@ -23,6 +23,7 @@ from langchain_core.vectorstores import VectorStore from langchain_elasticsearch._utilities import ( DistanceStrategy, maximal_marginal_relevance, + with_user_agent_header, ) logger = logging.getLogger(__name__) @@ -526,9 +527,7 @@ class ElasticsearchStore(VectorStore): self.strategy = strategy if es_connection is not None: - headers = dict(es_connection._headers) - headers.update({"user-agent": self.get_user_agent()}) - self.client = es_connection.options(headers=headers) + self.client = es_connection elif es_url is not None or es_cloud_id is not None: self.client = ElasticsearchStore.connect_to_elasticsearch( es_url=es_url, @@ -544,11 +543,7 @@ class ElasticsearchStore(VectorStore): or valid credentials for creating a new connection.""" ) - @staticmethod - def get_user_agent() -> str: - from langchain_core import __version__ - - return f"langchain-py-vs/{__version__}" + self.client = with_user_agent_header(self.client, "langchain-py-vs") @staticmethod def connect_to_elasticsearch( @@ -582,10 +577,7 @@ class ElasticsearchStore(VectorStore): if es_params is not None: connection_params.update(es_params) - es_client = Elasticsearch( - **connection_params, - headers={"user-agent": ElasticsearchStore.get_user_agent()}, - ) + es_client = Elasticsearch(**connection_params) try: es_client.info() except Exception as e: diff --git a/libs/partners/elasticsearch/poetry.lock b/libs/partners/elasticsearch/poetry.lock index 1eb81e91cd3..de9f74eca1f 100644 --- a/libs/partners/elasticsearch/poetry.lock +++ b/libs/partners/elasticsearch/poetry.lock @@ -599,7 +599,7 @@ files = [ [[package]] name = "langchain" -version = "0.1.10" +version = "0.1.11" description = "Building applications with LLMs through composability" optional = false python-versions = ">=3.8.1,<4.0" @@ -612,9 +612,9 @@ async-timeout = {version = "^4.0.0", markers = "python_version < \"3.11\""} dataclasses-json = ">= 0.5.7, < 0.7" jsonpatch = "^1.33" langchain-community = ">=0.0.25,<0.1" -langchain-core = ">=0.1.28,<0.2" +langchain-core = ">=0.1.29,<0.2" langchain-text-splitters = ">=0.0.1,<0.1" -langsmith = "^0.1.14" +langsmith = "^0.1.17" numpy = "^1" pydantic = ">=1,<3" PyYAML = ">=5.3" @@ -671,7 +671,7 @@ url = "../../community" [[package]] name = "langchain-core" -version = "0.1.28" +version = "0.1.29" description = "Building applications with LLMs through composability" optional = false python-versions = ">=3.8.1,<4.0" @@ -716,13 +716,13 @@ url = "../../text-splitters" [[package]] name = "langsmith" -version = "0.1.14" +version = "0.1.21" description = "Client library to connect to the LangSmith LLM Tracing and Evaluation Platform." optional = false python-versions = ">=3.8.1,<4.0" files = [ - {file = "langsmith-0.1.14-py3-none-any.whl", hash = "sha256:ecb243057d2a43c2da0524fe395585bc3421bb5d24f1cdd53eb06fbe63e43a69"}, - {file = "langsmith-0.1.14.tar.gz", hash = "sha256:b95f267d25681f4c9862bb68236fba8a57a60ec7921ecfdaa125936807e51bde"}, + {file = "langsmith-0.1.21-py3-none-any.whl", hash = "sha256:ac3d455d9651879ed306500a0504a2b9b9909225ab178e2446a8bace75e65e23"}, + {file = "langsmith-0.1.21.tar.gz", hash = "sha256:eef6b8a0d3bec7fcfc69ac5b35a16365ffac025dab0c1a4d77d6a7f7d3bbd3de"}, ] [package.dependencies] @@ -732,13 +732,13 @@ requests = ">=2,<3" [[package]] name = "marshmallow" -version = "3.21.0" +version = "3.21.1" description = "A lightweight library for converting complex datatypes to and from native Python datatypes." optional = false python-versions = ">=3.8" files = [ - {file = "marshmallow-3.21.0-py3-none-any.whl", hash = "sha256:e7997f83571c7fd476042c2c188e4ee8a78900ca5e74bd9c8097afa56624e9bd"}, - {file = "marshmallow-3.21.0.tar.gz", hash = "sha256:20f53be28c6e374a711a16165fb22a8dc6003e3f7cda1285e3ca777b9193885b"}, + {file = "marshmallow-3.21.1-py3-none-any.whl", hash = "sha256:f085493f79efb0644f270a9bf2892843142d80d7174bbbd2f3713f2a589dc633"}, + {file = "marshmallow-3.21.1.tar.gz", hash = "sha256:4e65e9e0d80fc9e609574b9983cf32579f305c718afb30d7233ab818571768c3"}, ] [package.dependencies] @@ -1033,13 +1033,13 @@ testing = ["pytest", "pytest-benchmark"] [[package]] name = "pydantic" -version = "2.6.2" +version = "2.6.3" description = "Data validation using Python type hints" optional = false python-versions = ">=3.8" files = [ - {file = "pydantic-2.6.2-py3-none-any.whl", hash = "sha256:37a5432e54b12fecaa1049c5195f3d860a10e01bdfd24f1840ef14bd0d3aeab3"}, - {file = "pydantic-2.6.2.tar.gz", hash = "sha256:a09be1c3d28f3abe37f8a78af58284b236a92ce520105ddc91a6d29ea1176ba7"}, + {file = "pydantic-2.6.3-py3-none-any.whl", hash = "sha256:72c6034df47f46ccdf81869fddb81aade68056003900a8724a4f160700016a2a"}, + {file = "pydantic-2.6.3.tar.gz", hash = "sha256:e07805c4c7f5c6826e33a1d4c9d47950d7eaf34868e2690f8594d2e30241f11f"}, ] [package.dependencies] @@ -1215,13 +1215,13 @@ watchdog = ">=2.0.0" [[package]] name = "python-dateutil" -version = "2.8.2" +version = "2.9.0.post0" description = "Extensions to the standard Python datetime module" optional = false python-versions = "!=3.0.*,!=3.1.*,!=3.2.*,>=2.7" files = [ - {file = "python-dateutil-2.8.2.tar.gz", hash = "sha256:0123cacc1627ae19ddf3c27a5de5bd67ee4586fbdd6440d9748f8abb483d3e86"}, - {file = "python_dateutil-2.8.2-py2.py3-none-any.whl", hash = "sha256:961d03dc3453ebbc59dbdea9e4e11c5651520a876d0f4db161e8674aae935da9"}, + {file = "python-dateutil-2.9.0.post0.tar.gz", hash = "sha256:37dd54208da7e1cd875388217d5e00ebd4179249f90fb72437e91a35459a0ad3"}, + {file = "python_dateutil-2.9.0.post0-py2.py3-none-any.whl", hash = "sha256:a8b2bc7bffae282281c8140a97d3aa9c14da0b136dfe83f850eea9a5f7470427"}, ] [package.dependencies] @@ -1252,6 +1252,7 @@ files = [ {file = "PyYAML-6.0.1-cp311-cp311-win_amd64.whl", hash = "sha256:bf07ee2fef7014951eeb99f56f39c9bb4af143d8aa3c21b1677805985307da34"}, {file = "PyYAML-6.0.1-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:855fb52b0dc35af121542a76b9a84f8d1cd886ea97c84703eaa6d88e37a2ad28"}, {file = "PyYAML-6.0.1-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:40df9b996c2b73138957fe23a16a4f0ba614f4c0efce1e9406a184b6d07fa3a9"}, + {file = "PyYAML-6.0.1-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a08c6f0fe150303c1c6b71ebcd7213c2858041a7e01975da3a99aed1e7a378ef"}, {file = "PyYAML-6.0.1-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:6c22bec3fbe2524cde73d7ada88f6566758a8f7227bfbf93a408a9d86bcc12a0"}, {file = "PyYAML-6.0.1-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:8d4e9c88387b0f5c7d5f281e55304de64cf7f9c0021a3525bd3b1c542da3b0e4"}, {file = "PyYAML-6.0.1-cp312-cp312-win32.whl", hash = "sha256:d483d2cdf104e7c9fa60c544d92981f12ad66a457afae824d146093b8c294c54"}, @@ -1357,60 +1358,60 @@ files = [ [[package]] name = "sqlalchemy" -version = "2.0.27" +version = "2.0.28" description = "Database Abstraction Library" optional = false python-versions = ">=3.7" files = [ - {file = "SQLAlchemy-2.0.27-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:d04e579e911562f1055d26dab1868d3e0bb905db3bccf664ee8ad109f035618a"}, - {file = "SQLAlchemy-2.0.27-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:fa67d821c1fd268a5a87922ef4940442513b4e6c377553506b9db3b83beebbd8"}, - {file = "SQLAlchemy-2.0.27-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:6c7a596d0be71b7baa037f4ac10d5e057d276f65a9a611c46970f012752ebf2d"}, - {file = "SQLAlchemy-2.0.27-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:954d9735ee9c3fa74874c830d089a815b7b48df6f6b6e357a74130e478dbd951"}, - {file = "SQLAlchemy-2.0.27-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:5cd20f58c29bbf2680039ff9f569fa6d21453fbd2fa84dbdb4092f006424c2e6"}, - {file = "SQLAlchemy-2.0.27-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:03f448ffb731b48323bda68bcc93152f751436ad6037f18a42b7e16af9e91c07"}, - {file = "SQLAlchemy-2.0.27-cp310-cp310-win32.whl", hash = "sha256:d997c5938a08b5e172c30583ba6b8aad657ed9901fc24caf3a7152eeccb2f1b4"}, - {file = "SQLAlchemy-2.0.27-cp310-cp310-win_amd64.whl", hash = "sha256:eb15ef40b833f5b2f19eeae65d65e191f039e71790dd565c2af2a3783f72262f"}, - {file = "SQLAlchemy-2.0.27-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:6c5bad7c60a392850d2f0fee8f355953abaec878c483dd7c3836e0089f046bf6"}, - {file = "SQLAlchemy-2.0.27-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:a3012ab65ea42de1be81fff5fb28d6db893ef978950afc8130ba707179b4284a"}, - {file = "SQLAlchemy-2.0.27-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:dbcd77c4d94b23e0753c5ed8deba8c69f331d4fd83f68bfc9db58bc8983f49cd"}, - {file = "SQLAlchemy-2.0.27-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:d177b7e82f6dd5e1aebd24d9c3297c70ce09cd1d5d37b43e53f39514379c029c"}, - {file = "SQLAlchemy-2.0.27-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:680b9a36029b30cf063698755d277885d4a0eab70a2c7c6e71aab601323cba45"}, - {file = "SQLAlchemy-2.0.27-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:1306102f6d9e625cebaca3d4c9c8f10588735ef877f0360b5cdb4fdfd3fd7131"}, - {file = "SQLAlchemy-2.0.27-cp311-cp311-win32.whl", hash = "sha256:5b78aa9f4f68212248aaf8943d84c0ff0f74efc65a661c2fc68b82d498311fd5"}, - {file = "SQLAlchemy-2.0.27-cp311-cp311-win_amd64.whl", hash = "sha256:15e19a84b84528f52a68143439d0c7a3a69befcd4f50b8ef9b7b69d2628ae7c4"}, - {file = "SQLAlchemy-2.0.27-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:0de1263aac858f288a80b2071990f02082c51d88335a1db0d589237a3435fe71"}, - {file = "SQLAlchemy-2.0.27-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:ce850db091bf7d2a1f2fdb615220b968aeff3849007b1204bf6e3e50a57b3d32"}, - {file = "SQLAlchemy-2.0.27-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:8dfc936870507da96aebb43e664ae3a71a7b96278382bcfe84d277b88e379b18"}, - {file = "SQLAlchemy-2.0.27-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:c4fbe6a766301f2e8a4519f4500fe74ef0a8509a59e07a4085458f26228cd7cc"}, - {file = "SQLAlchemy-2.0.27-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:4535c49d961fe9a77392e3a630a626af5baa967172d42732b7a43496c8b28876"}, - {file = "SQLAlchemy-2.0.27-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:0fb3bffc0ced37e5aa4ac2416f56d6d858f46d4da70c09bb731a246e70bff4d5"}, - {file = "SQLAlchemy-2.0.27-cp312-cp312-win32.whl", hash = "sha256:7f470327d06400a0aa7926b375b8e8c3c31d335e0884f509fe272b3c700a7254"}, - {file = "SQLAlchemy-2.0.27-cp312-cp312-win_amd64.whl", hash = "sha256:f9374e270e2553653d710ece397df67db9d19c60d2647bcd35bfc616f1622dcd"}, - {file = "SQLAlchemy-2.0.27-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:e97cf143d74a7a5a0f143aa34039b4fecf11343eed66538610debc438685db4a"}, - {file = "SQLAlchemy-2.0.27-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:d7b5a3e2120982b8b6bd1d5d99e3025339f7fb8b8267551c679afb39e9c7c7f1"}, - {file = "SQLAlchemy-2.0.27-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:e36aa62b765cf9f43a003233a8c2d7ffdeb55bc62eaa0a0380475b228663a38f"}, - {file = "SQLAlchemy-2.0.27-cp37-cp37m-musllinux_1_1_aarch64.whl", hash = "sha256:5ada0438f5b74c3952d916c199367c29ee4d6858edff18eab783b3978d0db16d"}, - {file = "SQLAlchemy-2.0.27-cp37-cp37m-musllinux_1_1_x86_64.whl", hash = "sha256:b1d9d1bfd96eef3c3faedb73f486c89e44e64e40e5bfec304ee163de01cf996f"}, - {file = "SQLAlchemy-2.0.27-cp37-cp37m-win32.whl", hash = "sha256:ca891af9f3289d24a490a5fde664ea04fe2f4984cd97e26de7442a4251bd4b7c"}, - {file = "SQLAlchemy-2.0.27-cp37-cp37m-win_amd64.whl", hash = "sha256:fd8aafda7cdff03b905d4426b714601c0978725a19efc39f5f207b86d188ba01"}, - {file = "SQLAlchemy-2.0.27-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:ec1f5a328464daf7a1e4e385e4f5652dd9b1d12405075ccba1df842f7774b4fc"}, - {file = "SQLAlchemy-2.0.27-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:ad862295ad3f644e3c2c0d8b10a988e1600d3123ecb48702d2c0f26771f1c396"}, - {file = "SQLAlchemy-2.0.27-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:48217be1de7d29a5600b5c513f3f7664b21d32e596d69582be0a94e36b8309cb"}, - {file = "SQLAlchemy-2.0.27-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:9e56afce6431450442f3ab5973156289bd5ec33dd618941283847c9fd5ff06bf"}, - {file = "SQLAlchemy-2.0.27-cp38-cp38-musllinux_1_1_aarch64.whl", hash = "sha256:611068511b5531304137bcd7fe8117c985d1b828eb86043bd944cebb7fae3910"}, - {file = "SQLAlchemy-2.0.27-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:b86abba762ecfeea359112b2bb4490802b340850bbee1948f785141a5e020de8"}, - {file = "SQLAlchemy-2.0.27-cp38-cp38-win32.whl", hash = "sha256:30d81cc1192dc693d49d5671cd40cdec596b885b0ce3b72f323888ab1c3863d5"}, - {file = "SQLAlchemy-2.0.27-cp38-cp38-win_amd64.whl", hash = "sha256:120af1e49d614d2525ac247f6123841589b029c318b9afbfc9e2b70e22e1827d"}, - {file = "SQLAlchemy-2.0.27-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:d07ee7793f2aeb9b80ec8ceb96bc8cc08a2aec8a1b152da1955d64e4825fcbac"}, - {file = "SQLAlchemy-2.0.27-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:cb0845e934647232b6ff5150df37ceffd0b67b754b9fdbb095233deebcddbd4a"}, - {file = "SQLAlchemy-2.0.27-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:1fc19ae2e07a067663dd24fca55f8ed06a288384f0e6e3910420bf4b1270cc51"}, - {file = "SQLAlchemy-2.0.27-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:b90053be91973a6fb6020a6e44382c97739736a5a9d74e08cc29b196639eb979"}, - {file = "SQLAlchemy-2.0.27-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:2f5c9dfb0b9ab5e3a8a00249534bdd838d943ec4cfb9abe176a6c33408430230"}, - {file = "SQLAlchemy-2.0.27-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:33e8bde8fff203de50399b9039c4e14e42d4d227759155c21f8da4a47fc8053c"}, - {file = "SQLAlchemy-2.0.27-cp39-cp39-win32.whl", hash = "sha256:d873c21b356bfaf1589b89090a4011e6532582b3a8ea568a00e0c3aab09399dd"}, - {file = "SQLAlchemy-2.0.27-cp39-cp39-win_amd64.whl", hash = "sha256:ff2f1b7c963961d41403b650842dc2039175b906ab2093635d8319bef0b7d620"}, - {file = "SQLAlchemy-2.0.27-py3-none-any.whl", hash = "sha256:1ab4e0448018d01b142c916cc7119ca573803a4745cfe341b8f95657812700ac"}, - {file = "SQLAlchemy-2.0.27.tar.gz", hash = "sha256:86a6ed69a71fe6b88bf9331594fa390a2adda4a49b5c06f98e47bf0d392534f8"}, + {file = "SQLAlchemy-2.0.28-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:e0b148ab0438f72ad21cb004ce3bdaafd28465c4276af66df3b9ecd2037bf252"}, + {file = "SQLAlchemy-2.0.28-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:bbda76961eb8f27e6ad3c84d1dc56d5bc61ba8f02bd20fcf3450bd421c2fcc9c"}, + {file = "SQLAlchemy-2.0.28-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:feea693c452d85ea0015ebe3bb9cd15b6f49acc1a31c28b3c50f4db0f8fb1e71"}, + {file = "SQLAlchemy-2.0.28-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:5da98815f82dce0cb31fd1e873a0cb30934971d15b74e0d78cf21f9e1b05953f"}, + {file = "SQLAlchemy-2.0.28-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:4a5adf383c73f2d49ad15ff363a8748319ff84c371eed59ffd0127355d6ea1da"}, + {file = "SQLAlchemy-2.0.28-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:56856b871146bfead25fbcaed098269d90b744eea5cb32a952df00d542cdd368"}, + {file = "SQLAlchemy-2.0.28-cp310-cp310-win32.whl", hash = "sha256:943aa74a11f5806ab68278284a4ddd282d3fb348a0e96db9b42cb81bf731acdc"}, + {file = "SQLAlchemy-2.0.28-cp310-cp310-win_amd64.whl", hash = "sha256:c6c4da4843e0dabde41b8f2e8147438330924114f541949e6318358a56d1875a"}, + {file = "SQLAlchemy-2.0.28-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:46a3d4e7a472bfff2d28db838669fc437964e8af8df8ee1e4548e92710929adc"}, + {file = "SQLAlchemy-2.0.28-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:0d3dd67b5d69794cfe82862c002512683b3db038b99002171f624712fa71aeaa"}, + {file = "SQLAlchemy-2.0.28-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:c61e2e41656a673b777e2f0cbbe545323dbe0d32312f590b1bc09da1de6c2a02"}, + {file = "SQLAlchemy-2.0.28-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:0315d9125a38026227f559488fe7f7cee1bd2fbc19f9fd637739dc50bb6380b2"}, + {file = "SQLAlchemy-2.0.28-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:af8ce2d31679006e7b747d30a89cd3ac1ec304c3d4c20973f0f4ad58e2d1c4c9"}, + {file = "SQLAlchemy-2.0.28-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:81ba314a08c7ab701e621b7ad079c0c933c58cdef88593c59b90b996e8b58fa5"}, + {file = "SQLAlchemy-2.0.28-cp311-cp311-win32.whl", hash = "sha256:1ee8bd6d68578e517943f5ebff3afbd93fc65f7ef8f23becab9fa8fb315afb1d"}, + {file = "SQLAlchemy-2.0.28-cp311-cp311-win_amd64.whl", hash = "sha256:ad7acbe95bac70e4e687a4dc9ae3f7a2f467aa6597049eeb6d4a662ecd990bb6"}, + {file = "SQLAlchemy-2.0.28-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:d3499008ddec83127ab286c6f6ec82a34f39c9817f020f75eca96155f9765097"}, + {file = "SQLAlchemy-2.0.28-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:9b66fcd38659cab5d29e8de5409cdf91e9986817703e1078b2fdaad731ea66f5"}, + {file = "SQLAlchemy-2.0.28-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:bea30da1e76cb1acc5b72e204a920a3a7678d9d52f688f087dc08e54e2754c67"}, + {file = "SQLAlchemy-2.0.28-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:124202b4e0edea7f08a4db8c81cc7859012f90a0d14ba2bf07c099aff6e96462"}, + {file = "SQLAlchemy-2.0.28-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:e23b88c69497a6322b5796c0781400692eca1ae5532821b39ce81a48c395aae9"}, + {file = "SQLAlchemy-2.0.28-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:4b6303bfd78fb3221847723104d152e5972c22367ff66edf09120fcde5ddc2e2"}, + {file = "SQLAlchemy-2.0.28-cp312-cp312-win32.whl", hash = "sha256:a921002be69ac3ab2cf0c3017c4e6a3377f800f1fca7f254c13b5f1a2f10022c"}, + {file = "SQLAlchemy-2.0.28-cp312-cp312-win_amd64.whl", hash = "sha256:b4a2cf92995635b64876dc141af0ef089c6eea7e05898d8d8865e71a326c0385"}, + {file = "SQLAlchemy-2.0.28-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:8e91b5e341f8c7f1e5020db8e5602f3ed045a29f8e27f7f565e0bdee3338f2c7"}, + {file = "SQLAlchemy-2.0.28-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:45c7b78dfc7278329f27be02c44abc0d69fe235495bb8e16ec7ef1b1a17952db"}, + {file = "SQLAlchemy-2.0.28-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:3eba73ef2c30695cb7eabcdb33bb3d0b878595737479e152468f3ba97a9c22a4"}, + {file = "SQLAlchemy-2.0.28-cp37-cp37m-musllinux_1_1_aarch64.whl", hash = "sha256:5df5d1dafb8eee89384fb7a1f79128118bc0ba50ce0db27a40750f6f91aa99d5"}, + {file = "SQLAlchemy-2.0.28-cp37-cp37m-musllinux_1_1_x86_64.whl", hash = "sha256:2858bbab1681ee5406650202950dc8f00e83b06a198741b7c656e63818633526"}, + {file = "SQLAlchemy-2.0.28-cp37-cp37m-win32.whl", hash = "sha256:9461802f2e965de5cff80c5a13bc945abea7edaa1d29360b485c3d2b56cdb075"}, + {file = "SQLAlchemy-2.0.28-cp37-cp37m-win_amd64.whl", hash = "sha256:a6bec1c010a6d65b3ed88c863d56b9ea5eeefdf62b5e39cafd08c65f5ce5198b"}, + {file = "SQLAlchemy-2.0.28-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:843a882cadebecc655a68bd9a5b8aa39b3c52f4a9a5572a3036fb1bb2ccdc197"}, + {file = "SQLAlchemy-2.0.28-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:dbb990612c36163c6072723523d2be7c3eb1517bbdd63fe50449f56afafd1133"}, + {file = "SQLAlchemy-2.0.28-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:bd7e4baf9161d076b9a7e432fce06217b9bd90cfb8f1d543d6e8c4595627edb9"}, + {file = "SQLAlchemy-2.0.28-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:e0a5354cb4de9b64bccb6ea33162cb83e03dbefa0d892db88a672f5aad638a75"}, + {file = "SQLAlchemy-2.0.28-cp38-cp38-musllinux_1_1_aarch64.whl", hash = "sha256:fffcc8edc508801ed2e6a4e7b0d150a62196fd28b4e16ab9f65192e8186102b6"}, + {file = "SQLAlchemy-2.0.28-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:aca7b6d99a4541b2ebab4494f6c8c2f947e0df4ac859ced575238e1d6ca5716b"}, + {file = "SQLAlchemy-2.0.28-cp38-cp38-win32.whl", hash = "sha256:8c7f10720fc34d14abad5b647bc8202202f4948498927d9f1b4df0fb1cf391b7"}, + {file = "SQLAlchemy-2.0.28-cp38-cp38-win_amd64.whl", hash = "sha256:243feb6882b06a2af68ecf4bec8813d99452a1b62ba2be917ce6283852cf701b"}, + {file = "SQLAlchemy-2.0.28-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:fc4974d3684f28b61b9a90fcb4c41fb340fd4b6a50c04365704a4da5a9603b05"}, + {file = "SQLAlchemy-2.0.28-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:87724e7ed2a936fdda2c05dbd99d395c91ea3c96f029a033a4a20e008dd876bf"}, + {file = "SQLAlchemy-2.0.28-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:68722e6a550f5de2e3cfe9da6afb9a7dd15ef7032afa5651b0f0c6b3adb8815d"}, + {file = "SQLAlchemy-2.0.28-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:328529f7c7f90adcd65aed06a161851f83f475c2f664a898af574893f55d9e53"}, + {file = "SQLAlchemy-2.0.28-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:df40c16a7e8be7413b885c9bf900d402918cc848be08a59b022478804ea076b8"}, + {file = "SQLAlchemy-2.0.28-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:426f2fa71331a64f5132369ede5171c52fd1df1bd9727ce621f38b5b24f48750"}, + {file = "SQLAlchemy-2.0.28-cp39-cp39-win32.whl", hash = "sha256:33157920b233bc542ce497a81a2e1452e685a11834c5763933b440fedd1d8e2d"}, + {file = "SQLAlchemy-2.0.28-cp39-cp39-win_amd64.whl", hash = "sha256:2f60843068e432311c886c5f03c4664acaef507cf716f6c60d5fde7265be9d7b"}, + {file = "SQLAlchemy-2.0.28-py3-none-any.whl", hash = "sha256:78bb7e8da0183a8301352d569900d9d3594c48ac21dc1c2ec6b3121ed8b6c986"}, + {file = "SQLAlchemy-2.0.28.tar.gz", hash = "sha256:dd53b6c4e6d960600fd6532b79ee28e2da489322fcf6648738134587faf767b6"}, ] [package.dependencies] diff --git a/libs/partners/elasticsearch/tests/integration_tests/_test_utilities.py b/libs/partners/elasticsearch/tests/integration_tests/_test_utilities.py new file mode 100644 index 00000000000..742a0a5d121 --- /dev/null +++ b/libs/partners/elasticsearch/tests/integration_tests/_test_utilities.py @@ -0,0 +1,42 @@ +import os +from typing import Any, Dict, List + +from elastic_transport import Transport +from elasticsearch import Elasticsearch + + +def clear_test_indices(es: Elasticsearch) -> None: + index_names = es.indices.get(index="_all").keys() + for index_name in index_names: + if index_name.startswith("test_"): + es.indices.delete(index=index_name) + es.indices.refresh(index="_all") + + +def requests_saving_es_client() -> Elasticsearch: + class CustomTransport(Transport): + def __init__(self, *args: Any, **kwargs: Any) -> None: + super().__init__(*args, **kwargs) + self.requests: List[Dict] = [] + + def perform_request(self, *args, **kwargs): # type: ignore + self.requests.append(kwargs) + return super().perform_request(*args, **kwargs) + + es_url = os.environ.get("ES_URL", "http://localhost:9200") + cloud_id = os.environ.get("ES_CLOUD_ID") + api_key = os.environ.get("ES_API_KEY") + + if cloud_id: + # Running this integration test with Elastic Cloud + # Required for in-stack inference testing (ELSER + model_id) + es = Elasticsearch( + cloud_id=cloud_id, + api_key=api_key, + transport_class=CustomTransport, + ) + else: + # Running this integration test with local docker instance + es = Elasticsearch(hosts=[es_url], transport_class=CustomTransport) + + return es diff --git a/libs/partners/elasticsearch/tests/integration_tests/docker-compose.yml b/libs/partners/elasticsearch/tests/integration_tests/docker-compose.yml new file mode 100644 index 00000000000..b39daa6ffaa --- /dev/null +++ b/libs/partners/elasticsearch/tests/integration_tests/docker-compose.yml @@ -0,0 +1,35 @@ +version: "3" + +services: + elasticsearch: + image: docker.elastic.co/elasticsearch/elasticsearch:8.12.1 # https://www.docker.elastic.co/r/elasticsearch/elasticsearch + environment: + - discovery.type=single-node + - xpack.security.enabled=false # security has been disabled, so no login or password is required. + - xpack.security.http.ssl.enabled=false + - xpack.license.self_generated.type=trial + ports: + - "9200:9200" + healthcheck: + test: + [ + "CMD-SHELL", + "curl --silent --fail http://localhost:9200/_cluster/health || exit 1" + ] + interval: 10s + retries: 60 + + kibana: + image: docker.elastic.co/kibana/kibana:8.12.1 + environment: + - ELASTICSEARCH_URL=http://elasticsearch:9200 + ports: + - "5601:5601" + healthcheck: + test: + [ + "CMD-SHELL", + "curl --silent --fail http://localhost:5601/login || exit 1" + ] + interval: 10s + retries: 60 diff --git a/libs/partners/elasticsearch/tests/integration_tests/test_chat_history.py b/libs/partners/elasticsearch/tests/integration_tests/test_chat_history.py index 5ddbcc4eb6b..8c13b2d0d63 100644 --- a/libs/partners/elasticsearch/tests/integration_tests/test_chat_history.py +++ b/libs/partners/elasticsearch/tests/integration_tests/test_chat_history.py @@ -10,8 +10,8 @@ from langchain_core.messages import message_to_dict from langchain_elasticsearch.chat_history import ElasticsearchChatMessageHistory """ -cd tests/integration_tests/memory/docker-compose -docker-compose -f elasticsearch.yml up +cd tests/integration_tests +docker-compose up elasticsearch By default runs against local docker instance of Elasticsearch. To run against Elastic Cloud, set the following environment variables: diff --git a/libs/partners/elasticsearch/tests/integration_tests/test_retrievers.py b/libs/partners/elasticsearch/tests/integration_tests/test_retrievers.py new file mode 100644 index 00000000000..79f9d8ef1d5 --- /dev/null +++ b/libs/partners/elasticsearch/tests/integration_tests/test_retrievers.py @@ -0,0 +1,169 @@ +"""Test ElasticsearchRetriever functionality.""" + +import re +import uuid +from typing import Any, Dict + +import pytest +from elasticsearch import Elasticsearch +from langchain_core.documents import Document + +from langchain_elasticsearch.retrievers import ElasticsearchRetriever + +from ._test_utilities import requests_saving_es_client + +""" +cd tests/integration_tests +docker-compose up elasticsearch + +By default runs against local docker instance of Elasticsearch. +To run against Elastic Cloud, set the following environment variables: +- ES_CLOUD_ID +- ES_API_KEY +""" + + +def index_test_data(es_client: Elasticsearch, index_name: str, field_name: str) -> None: + docs = [(1, "foo bar"), (2, "bar"), (3, "foo"), (4, "baz"), (5, "foo baz")] + for identifier, text in docs: + es_client.index( + index=index_name, + document={field_name: text, "another_field": 1}, + id=str(identifier), + refresh=True, + ) + + +class TestElasticsearchRetriever: + @pytest.fixture(scope="function") + def es_client(self) -> Any: + return requests_saving_es_client() + + @pytest.fixture(scope="function") + def index_name(self) -> str: + """Return the index name.""" + return f"test_{uuid.uuid4().hex}" + + def test_user_agent_header(self, es_client: Elasticsearch, index_name: str) -> None: + """Test that the user agent header is set correctly.""" + + retriever = ElasticsearchRetriever( + index_name=index_name, + body_func=lambda _: {"query": {"match_all": {}}}, + content_field="text", + es_client=es_client, + ) + + assert retriever.es_client + user_agent = retriever.es_client._headers["User-Agent"] + assert ( + re.match(r"^langchain-py-r/\d+\.\d+\.\d+$", user_agent) is not None + ), f"The string '{user_agent}' does not match the expected pattern." + + index_test_data(es_client, index_name, "text") + retriever.get_relevant_documents("foo") + + search_request = es_client.transport.requests[-1] # type: ignore[attr-defined] + user_agent = search_request["headers"]["User-Agent"] + assert ( + re.match(r"^langchain-py-r/\d+\.\d+\.\d+$", user_agent) is not None + ), f"The string '{user_agent}' does not match the expected pattern." + + def test_init_url(self, index_name: str) -> None: + """Test end-to-end indexing and search.""" + + text_field = "text" + + def body_func(query: str) -> Dict: + return {"query": {"match": {text_field: {"query": query}}}} + + retriever = ElasticsearchRetriever.from_es_params( + url="http://localhost:9200", + index_name=index_name, + body_func=body_func, + content_field=text_field, + ) + + index_test_data(retriever.es_client, index_name, text_field) + result = retriever.get_relevant_documents("foo") + + assert {r.page_content for r in result} == {"foo", "foo bar", "foo baz"} + assert {r.metadata["_id"] for r in result} == {"3", "1", "5"} + for r in result: + assert set(r.metadata.keys()) == {"_index", "_id", "_score", "_source"} + assert text_field not in r.metadata["_source"] + assert "another_field" in r.metadata["_source"] + + def test_init_client(self, es_client: Elasticsearch, index_name: str) -> None: + """Test end-to-end indexing and search.""" + + text_field = "text" + + def body_func(query: str) -> Dict: + return {"query": {"match": {text_field: {"query": query}}}} + + retriever = ElasticsearchRetriever( + index_name=index_name, + body_func=body_func, + content_field=text_field, + es_client=es_client, + ) + + index_test_data(es_client, index_name, text_field) + result = retriever.get_relevant_documents("foo") + + assert {r.page_content for r in result} == {"foo", "foo bar", "foo baz"} + assert {r.metadata["_id"] for r in result} == {"3", "1", "5"} + for r in result: + assert set(r.metadata.keys()) == {"_index", "_id", "_score", "_source"} + assert text_field not in r.metadata["_source"] + assert "another_field" in r.metadata["_source"] + + def test_custom_mapper(self, es_client: Elasticsearch, index_name: str) -> None: + """Test custom document maper""" + + text_field = "text" + meta = {"some_field": 12} + + def body_func(query: str) -> Dict: + return {"query": {"match": {text_field: {"query": query}}}} + + def id_as_content(hit: Dict) -> Document: + return Document(page_content=hit["_id"], metadata=meta) + + retriever = ElasticsearchRetriever( + index_name=index_name, + body_func=body_func, + document_mapper=id_as_content, + es_client=es_client, + ) + + index_test_data(es_client, index_name, text_field) + result = retriever.get_relevant_documents("foo") + + assert [r.page_content for r in result] == ["3", "1", "5"] + assert [r.metadata for r in result] == [meta, meta, meta] + + def test_fail_content_field_and_mapper(self, es_client: Elasticsearch) -> None: + """Raise exception if both content_field and document_mapper are specified.""" + + with pytest.raises(ValueError): + ElasticsearchRetriever( + content_field="text", + document_mapper=lambda x: x, + index_name="foo", + body_func=lambda x: x, + es_client=es_client, + ) + + def test_fail_neither_content_field_nor_mapper( + self, es_client: Elasticsearch + ) -> None: + """Raise exception if neither content_field nor document_mapper are specified""" + + with pytest.raises(ValueError): + ElasticsearchRetriever( + index_name="foo", + body_func=lambda x: x, + es_client=es_client, + ) diff --git a/libs/partners/elasticsearch/tests/integration_tests/test_vectorstores.py b/libs/partners/elasticsearch/tests/integration_tests/test_vectorstores.py index c46fc865595..d988b15395f 100644 --- a/libs/partners/elasticsearch/tests/integration_tests/test_vectorstores.py +++ b/libs/partners/elasticsearch/tests/integration_tests/test_vectorstores.py @@ -7,7 +7,6 @@ import uuid from typing import Any, Dict, Generator, List, Union import pytest -from elastic_transport import Transport from elasticsearch import Elasticsearch from elasticsearch.helpers import BulkIndexError from langchain_core.documents import Document @@ -18,12 +17,13 @@ from ..fake_embeddings import ( ConsistentFakeEmbeddings, FakeEmbeddings, ) +from ._test_utilities import clear_test_indices, requests_saving_es_client logging.basicConfig(level=logging.DEBUG) """ -cd tests/integration_tests/vectorstores/docker-compose -docker-compose -f elasticsearch.yml up +cd tests/integration_tests +docker-compose up elasticsearch By default runs against local docker instance of Elasticsearch. To run against Elastic Cloud, set the following environment variables: @@ -74,12 +74,8 @@ class TestElasticsearch: es = Elasticsearch(hosts=es_url) yield {"es_url": es_url} - # Clear all indexes - index_names = es.indices.get(index="_all").keys() - for index_name in index_names: - if index_name.startswith("test_"): - es.indices.delete(index=index_name) - es.indices.refresh(index="_all") + # clear indices + clear_test_indices(es) # clear all test pipelines try: @@ -94,32 +90,11 @@ class TestElasticsearch: except Exception: pass + return None + @pytest.fixture(scope="function") def es_client(self) -> Any: - class CustomTransport(Transport): - requests = [] - - def perform_request(self, *args, **kwargs): # type: ignore - self.requests.append(kwargs) - return super().perform_request(*args, **kwargs) - - es_url = os.environ.get("ES_URL", "http://localhost:9200") - cloud_id = os.environ.get("ES_CLOUD_ID") - api_key = os.environ.get("ES_API_KEY") - - if cloud_id: - # Running this integration test with Elastic Cloud - # Required for in-stack inference testing (ELSER + model_id) - es = Elasticsearch( - cloud_id=cloud_id, - api_key=api_key, - transport_class=CustomTransport, - ) - return es - else: - # Running this integration test with local docker instance - es = Elasticsearch(hosts=es_url, transport_class=CustomTransport) - return es + return requests_saving_es_client() @pytest.fixture(scope="function") def index_name(self) -> str: @@ -887,11 +862,8 @@ class TestElasticsearch: ) user_agent = es_client.transport.requests[0]["headers"]["User-Agent"] - pattern = r"^langchain-py-vs/\d+\.\d+\.\d+$" - match = re.match(pattern, user_agent) - assert ( - match is not None + re.match(r"^langchain-py-vs/\d+\.\d+\.\d+$", user_agent) is not None ), f"The string '{user_agent}' does not match the expected pattern." def test_elasticsearch_with_internal_user_agent( @@ -908,15 +880,12 @@ class TestElasticsearch: ) user_agent = store.client._headers["User-Agent"] - pattern = r"^langchain-py-vs/\d+\.\d+\.\d+$" - match = re.match(pattern, user_agent) - assert ( - match is not None + re.match(r"^langchain-py-vs/\d+\.\d+\.\d+$", user_agent) is not None ), f"The string '{user_agent}' does not match the expected pattern." def test_bulk_args(self, es_client: Any, index_name: str) -> None: - """Test to make sure the user-agent is set correctly.""" + """Test to make sure the bulk arguments work as expected.""" texts = ["foo", "bob", "baz"] ElasticsearchStore.from_texts(