elasticsearch: add ElasticsearchRetriever (#18587)

Implement
[Retriever](https://python.langchain.com/docs/modules/data_connection/retrievers/)
interface for Elasticsearch.

I opted to only expose the `body`, which gives you full flexibility, and
none of the other 68 arguments of the [search
method](https://elasticsearch-py.readthedocs.io/en/v8.12.1/api/elasticsearch.html#elasticsearch.Elasticsearch.search).

Added a user agent header for usage tracking in Elastic Cloud.

---------

Co-authored-by: Erick Friis <erick@langchain.dev>
This commit is contained in:
Max Jakob 2024-03-06 01:42:50 +01:00 committed by GitHub
parent 8bc347c5fc
commit ee7a7954b9
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
11 changed files with 493 additions and 187 deletions

View File

@ -2,6 +2,8 @@ from enum import Enum
from typing import List, Union from typing import List, Union
import numpy as np import numpy as np
from elasticsearch import Elasticsearch
from langchain_core import __version__ as langchain_version
Matrix = Union[List[List[float]], List[np.ndarray], np.ndarray] Matrix = Union[List[List[float]], List[np.ndarray], np.ndarray]
@ -17,6 +19,12 @@ class DistanceStrategy(str, Enum):
COSINE = "COSINE" COSINE = "COSINE"
def with_user_agent_header(client: Elasticsearch, header_prefix: str) -> Elasticsearch:
    """Return a client copy that sends a ``<header_prefix>/<langchain version>``
    user-agent header with every request (used for usage tracking in Elastic
    Cloud). The original client is left unmodified.
    """
    merged_headers = dict(client._headers)
    merged_headers["user-agent"] = f"{header_prefix}/{langchain_version}"
    return client.options(headers=merged_headers)
def maximal_marginal_relevance( def maximal_marginal_relevance(
query_embedding: np.ndarray, query_embedding: np.ndarray,
embedding_list: list, embedding_list: list,

View File

@ -1,7 +1,7 @@
import json import json
import logging import logging
from time import time from time import time
from typing import TYPE_CHECKING, Any, Dict, List, Optional from typing import TYPE_CHECKING, List, Optional
from langchain_core.chat_history import BaseChatMessageHistory from langchain_core.chat_history import BaseChatMessageHistory
from langchain_core.messages import ( from langchain_core.messages import (
@ -10,6 +10,9 @@ from langchain_core.messages import (
messages_from_dict, messages_from_dict,
) )
from langchain_elasticsearch._utilities import with_user_agent_header
from langchain_elasticsearch.client import create_elasticsearch_client
if TYPE_CHECKING: if TYPE_CHECKING:
from elasticsearch import Elasticsearch from elasticsearch import Elasticsearch
@ -51,23 +54,27 @@ class ElasticsearchChatMessageHistory(BaseChatMessageHistory):
# Initialize Elasticsearch client from passed client arg or connection info # Initialize Elasticsearch client from passed client arg or connection info
if es_connection is not None: if es_connection is not None:
self.client = es_connection.options( self.client = es_connection
headers={"user-agent": self.get_user_agent()}
)
elif es_url is not None or es_cloud_id is not None: elif es_url is not None or es_cloud_id is not None:
self.client = ElasticsearchChatMessageHistory.connect_to_elasticsearch( try:
es_url=es_url, self.client = create_elasticsearch_client(
url=es_url,
username=es_user, username=es_user,
password=es_password, password=es_password,
cloud_id=es_cloud_id, cloud_id=es_cloud_id,
api_key=es_api_key, api_key=es_api_key,
) )
except Exception as err:
logger.error(f"Error connecting to Elasticsearch: {err}")
raise err
else: else:
raise ValueError( raise ValueError(
"""Either provide a pre-existing Elasticsearch connection, \ """Either provide a pre-existing Elasticsearch connection, \
or valid credentials for creating a new connection.""" or valid credentials for creating a new connection."""
) )
self.client = with_user_agent_header(self.client, "langchain-py-ms")
if self.client.indices.exists(index=index): if self.client.indices.exists(index=index):
logger.debug( logger.debug(
f"Chat history index {index} already exists, skipping creation." f"Chat history index {index} already exists, skipping creation."
@ -86,60 +93,6 @@ class ElasticsearchChatMessageHistory(BaseChatMessageHistory):
}, },
) )
@staticmethod
def get_user_agent() -> str:
from langchain_core import __version__
return f"langchain-py-ms/{__version__}"
@staticmethod
def connect_to_elasticsearch(
*,
es_url: Optional[str] = None,
cloud_id: Optional[str] = None,
api_key: Optional[str] = None,
username: Optional[str] = None,
password: Optional[str] = None,
) -> "Elasticsearch":
try:
import elasticsearch
except ImportError:
raise ImportError(
"Could not import elasticsearch python package. "
"Please install it with `pip install elasticsearch`."
)
if es_url and cloud_id:
raise ValueError(
"Both es_url and cloud_id are defined. Please provide only one."
)
connection_params: Dict[str, Any] = {}
if es_url:
connection_params["hosts"] = [es_url]
elif cloud_id:
connection_params["cloud_id"] = cloud_id
else:
raise ValueError("Please provide either elasticsearch_url or cloud_id.")
if api_key:
connection_params["api_key"] = api_key
elif username and password:
connection_params["basic_auth"] = (username, password)
es_client = elasticsearch.Elasticsearch(
**connection_params,
headers={"user-agent": ElasticsearchChatMessageHistory.get_user_agent()},
)
try:
es_client.info()
except Exception as err:
logger.error(f"Error connecting to Elasticsearch: {err}")
raise err
return es_client
@property @property
def messages(self) -> List[BaseMessage]: # type: ignore[override] def messages(self) -> List[BaseMessage]: # type: ignore[override]
"""Retrieve the messages from Elasticsearch""" """Retrieve the messages from Elasticsearch"""

View File

@ -0,0 +1,40 @@
from typing import Any, Dict, Optional
from elasticsearch import Elasticsearch
def create_elasticsearch_client(
    url: Optional[str] = None,
    cloud_id: Optional[str] = None,
    api_key: Optional[str] = None,
    username: Optional[str] = None,
    password: Optional[str] = None,
    params: Optional[Dict[str, Any]] = None,
) -> Elasticsearch:
    """Instantiate an Elasticsearch client and verify the connection.

    Args:
        url: URL of an Elasticsearch node, e.g. ``"http://localhost:9200"``.
            Mutually exclusive with ``cloud_id``.
        cloud_id: Elastic Cloud deployment ID. Mutually exclusive with ``url``.
        api_key: API key for authentication; takes precedence over
            ``username``/``password``.
        username: Basic-auth user name (only used together with ``password``).
        password: Basic-auth password.
        params: Extra keyword arguments passed verbatim to the
            ``Elasticsearch`` constructor; applied last, so they may override
            any of the values above.

    Returns:
        A connected ``Elasticsearch`` client.

    Raises:
        ValueError: If both or neither of ``url`` and ``cloud_id`` are given.
    """
    if url and cloud_id:
        # NOTE: previous message referred to "es_url", a parameter name this
        # function does not have; use the actual parameter names.
        raise ValueError("Both url and cloud_id are defined. Please provide only one.")

    connection_params: Dict[str, Any] = {}
    if url:
        connection_params["hosts"] = [url]
    elif cloud_id:
        connection_params["cloud_id"] = cloud_id
    else:
        raise ValueError("Please provide either a url or a cloud_id.")

    if api_key:
        connection_params["api_key"] = api_key
    elif username and password:
        connection_params["basic_auth"] = (username, password)

    if params is not None:
        connection_params.update(params)

    es_client = Elasticsearch(**connection_params)
    es_client.info()  # fail fast if the cluster is unreachable or auth is wrong
    return es_client

View File

@ -0,0 +1,97 @@
import logging
from typing import Any, Callable, Dict, List, Optional
from elasticsearch import Elasticsearch
from langchain_core.callbacks import CallbackManagerForRetrieverRun
from langchain_core.documents import Document
from langchain_core.retrievers import BaseRetriever
from langchain_elasticsearch._utilities import with_user_agent_header
from langchain_elasticsearch.client import create_elasticsearch_client
logger = logging.getLogger(__name__)
class ElasticsearchRetriever(BaseRetriever):
    """
    Elasticsearch retriever.

    Args:
        es_client: Elasticsearch client connection. Alternatively you can use the
            `from_es_params` method with parameters to initialize the client.
        index_name: The name of the index to query.
        body_func: Function to create an Elasticsearch DSL query body from a search
            string. All parameters (including for example the `size` parameter to limit
            the number of results) must also be set in the body.
        content_field: The document field name that contains the page content.
        document_mapper: Function to map Elasticsearch hits to LangChain Documents.

    Exactly one of ``content_field`` and ``document_mapper`` must be provided;
    this is validated in ``__init__``.
    """

    es_client: Elasticsearch
    index_name: str
    body_func: Callable[[str], Dict]
    # content_field is a convenience alternative to document_mapper: when set,
    # hits are mapped via _field_mapper using this field as the page content.
    content_field: Optional[str] = None
    document_mapper: Optional[Callable[[Dict], Document]] = None

    def __init__(self, **kwargs: Any) -> None:
        # BaseRetriever (pydantic) validates and assigns the fields above.
        super().__init__(**kwargs)
        # Enforce that exactly one mapping strategy was configured.
        if self.content_field is None and self.document_mapper is None:
            raise ValueError("One of content_field or document_mapper must be defined.")
        if self.content_field is not None and self.document_mapper is not None:
            raise ValueError(
                "Both content_field and document_mapper are defined. "
                "Please provide only one."
            )

        # Fall back to the field-based mapper when no custom mapper was given.
        self.document_mapper = self.document_mapper or self._field_mapper
        # Tag requests with a "langchain-py-r/<version>" user-agent header for
        # usage tracking in Elastic Cloud.
        self.es_client = with_user_agent_header(self.es_client, "langchain-py-r")

    @staticmethod
    def from_es_params(
        index_name: str,
        body_func: Callable[[str], Dict],
        content_field: Optional[str] = None,
        document_mapper: Optional[Callable[[Dict], Document]] = None,
        url: Optional[str] = None,
        cloud_id: Optional[str] = None,
        api_key: Optional[str] = None,
        username: Optional[str] = None,
        password: Optional[str] = None,
        params: Optional[Dict[str, Any]] = None,
    ) -> "ElasticsearchRetriever":
        """Create an ``ElasticsearchRetriever``, building the Elasticsearch
        client from connection parameters instead of requiring a pre-built
        ``es_client``. Connection errors are logged and re-raised.
        """
        client = None
        try:
            client = create_elasticsearch_client(
                url=url,
                cloud_id=cloud_id,
                api_key=api_key,
                username=username,
                password=password,
                params=params,
            )
        except Exception as err:
            logger.error(f"Error connecting to Elasticsearch: {err}")
            raise err

        return ElasticsearchRetriever(
            es_client=client,
            index_name=index_name,
            body_func=body_func,
            content_field=content_field,
            document_mapper=document_mapper,
        )

    def _get_relevant_documents(
        self, query: str, *, run_manager: CallbackManagerForRetrieverRun
    ) -> List[Document]:
        """Run the DSL query produced by ``body_func`` against ``index_name``
        and map each hit to a ``Document`` with ``document_mapper``.
        """
        if not self.es_client or not self.document_mapper:
            raise ValueError("faulty configuration")  # should not happen

        body = self.body_func(query)
        results = self.es_client.search(index=self.index_name, body=body)
        return [self.document_mapper(hit) for hit in results["hits"]["hits"]]

    def _field_mapper(self, hit: Dict[str, Any]) -> Document:
        # Default hit-to-Document mapping: use the configured content_field as
        # the page content; the remaining hit (including the rest of _source)
        # becomes the metadata. Note: pop() mutates the hit's _source dict.
        content = hit["_source"].pop(self.content_field)
        return Document(page_content=content, metadata=hit)

View File

@ -23,6 +23,7 @@ from langchain_core.vectorstores import VectorStore
from langchain_elasticsearch._utilities import ( from langchain_elasticsearch._utilities import (
DistanceStrategy, DistanceStrategy,
maximal_marginal_relevance, maximal_marginal_relevance,
with_user_agent_header,
) )
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -526,9 +527,7 @@ class ElasticsearchStore(VectorStore):
self.strategy = strategy self.strategy = strategy
if es_connection is not None: if es_connection is not None:
headers = dict(es_connection._headers) self.client = es_connection
headers.update({"user-agent": self.get_user_agent()})
self.client = es_connection.options(headers=headers)
elif es_url is not None or es_cloud_id is not None: elif es_url is not None or es_cloud_id is not None:
self.client = ElasticsearchStore.connect_to_elasticsearch( self.client = ElasticsearchStore.connect_to_elasticsearch(
es_url=es_url, es_url=es_url,
@ -544,11 +543,7 @@ class ElasticsearchStore(VectorStore):
or valid credentials for creating a new connection.""" or valid credentials for creating a new connection."""
) )
@staticmethod self.client = with_user_agent_header(self.client, "langchain-py-vs")
def get_user_agent() -> str:
from langchain_core import __version__
return f"langchain-py-vs/{__version__}"
@staticmethod @staticmethod
def connect_to_elasticsearch( def connect_to_elasticsearch(
@ -582,10 +577,7 @@ class ElasticsearchStore(VectorStore):
if es_params is not None: if es_params is not None:
connection_params.update(es_params) connection_params.update(es_params)
es_client = Elasticsearch( es_client = Elasticsearch(**connection_params)
**connection_params,
headers={"user-agent": ElasticsearchStore.get_user_agent()},
)
try: try:
es_client.info() es_client.info()
except Exception as e: except Exception as e:

View File

@ -599,7 +599,7 @@ files = [
[[package]] [[package]]
name = "langchain" name = "langchain"
version = "0.1.10" version = "0.1.11"
description = "Building applications with LLMs through composability" description = "Building applications with LLMs through composability"
optional = false optional = false
python-versions = ">=3.8.1,<4.0" python-versions = ">=3.8.1,<4.0"
@ -612,9 +612,9 @@ async-timeout = {version = "^4.0.0", markers = "python_version < \"3.11\""}
dataclasses-json = ">= 0.5.7, < 0.7" dataclasses-json = ">= 0.5.7, < 0.7"
jsonpatch = "^1.33" jsonpatch = "^1.33"
langchain-community = ">=0.0.25,<0.1" langchain-community = ">=0.0.25,<0.1"
langchain-core = ">=0.1.28,<0.2" langchain-core = ">=0.1.29,<0.2"
langchain-text-splitters = ">=0.0.1,<0.1" langchain-text-splitters = ">=0.0.1,<0.1"
langsmith = "^0.1.14" langsmith = "^0.1.17"
numpy = "^1" numpy = "^1"
pydantic = ">=1,<3" pydantic = ">=1,<3"
PyYAML = ">=5.3" PyYAML = ">=5.3"
@ -671,7 +671,7 @@ url = "../../community"
[[package]] [[package]]
name = "langchain-core" name = "langchain-core"
version = "0.1.28" version = "0.1.29"
description = "Building applications with LLMs through composability" description = "Building applications with LLMs through composability"
optional = false optional = false
python-versions = ">=3.8.1,<4.0" python-versions = ">=3.8.1,<4.0"
@ -716,13 +716,13 @@ url = "../../text-splitters"
[[package]] [[package]]
name = "langsmith" name = "langsmith"
version = "0.1.14" version = "0.1.21"
description = "Client library to connect to the LangSmith LLM Tracing and Evaluation Platform." description = "Client library to connect to the LangSmith LLM Tracing and Evaluation Platform."
optional = false optional = false
python-versions = ">=3.8.1,<4.0" python-versions = ">=3.8.1,<4.0"
files = [ files = [
{file = "langsmith-0.1.14-py3-none-any.whl", hash = "sha256:ecb243057d2a43c2da0524fe395585bc3421bb5d24f1cdd53eb06fbe63e43a69"}, {file = "langsmith-0.1.21-py3-none-any.whl", hash = "sha256:ac3d455d9651879ed306500a0504a2b9b9909225ab178e2446a8bace75e65e23"},
{file = "langsmith-0.1.14.tar.gz", hash = "sha256:b95f267d25681f4c9862bb68236fba8a57a60ec7921ecfdaa125936807e51bde"}, {file = "langsmith-0.1.21.tar.gz", hash = "sha256:eef6b8a0d3bec7fcfc69ac5b35a16365ffac025dab0c1a4d77d6a7f7d3bbd3de"},
] ]
[package.dependencies] [package.dependencies]
@ -732,13 +732,13 @@ requests = ">=2,<3"
[[package]] [[package]]
name = "marshmallow" name = "marshmallow"
version = "3.21.0" version = "3.21.1"
description = "A lightweight library for converting complex datatypes to and from native Python datatypes." description = "A lightweight library for converting complex datatypes to and from native Python datatypes."
optional = false optional = false
python-versions = ">=3.8" python-versions = ">=3.8"
files = [ files = [
{file = "marshmallow-3.21.0-py3-none-any.whl", hash = "sha256:e7997f83571c7fd476042c2c188e4ee8a78900ca5e74bd9c8097afa56624e9bd"}, {file = "marshmallow-3.21.1-py3-none-any.whl", hash = "sha256:f085493f79efb0644f270a9bf2892843142d80d7174bbbd2f3713f2a589dc633"},
{file = "marshmallow-3.21.0.tar.gz", hash = "sha256:20f53be28c6e374a711a16165fb22a8dc6003e3f7cda1285e3ca777b9193885b"}, {file = "marshmallow-3.21.1.tar.gz", hash = "sha256:4e65e9e0d80fc9e609574b9983cf32579f305c718afb30d7233ab818571768c3"},
] ]
[package.dependencies] [package.dependencies]
@ -1033,13 +1033,13 @@ testing = ["pytest", "pytest-benchmark"]
[[package]] [[package]]
name = "pydantic" name = "pydantic"
version = "2.6.2" version = "2.6.3"
description = "Data validation using Python type hints" description = "Data validation using Python type hints"
optional = false optional = false
python-versions = ">=3.8" python-versions = ">=3.8"
files = [ files = [
{file = "pydantic-2.6.2-py3-none-any.whl", hash = "sha256:37a5432e54b12fecaa1049c5195f3d860a10e01bdfd24f1840ef14bd0d3aeab3"}, {file = "pydantic-2.6.3-py3-none-any.whl", hash = "sha256:72c6034df47f46ccdf81869fddb81aade68056003900a8724a4f160700016a2a"},
{file = "pydantic-2.6.2.tar.gz", hash = "sha256:a09be1c3d28f3abe37f8a78af58284b236a92ce520105ddc91a6d29ea1176ba7"}, {file = "pydantic-2.6.3.tar.gz", hash = "sha256:e07805c4c7f5c6826e33a1d4c9d47950d7eaf34868e2690f8594d2e30241f11f"},
] ]
[package.dependencies] [package.dependencies]
@ -1215,13 +1215,13 @@ watchdog = ">=2.0.0"
[[package]] [[package]]
name = "python-dateutil" name = "python-dateutil"
version = "2.8.2" version = "2.9.0.post0"
description = "Extensions to the standard Python datetime module" description = "Extensions to the standard Python datetime module"
optional = false optional = false
python-versions = "!=3.0.*,!=3.1.*,!=3.2.*,>=2.7" python-versions = "!=3.0.*,!=3.1.*,!=3.2.*,>=2.7"
files = [ files = [
{file = "python-dateutil-2.8.2.tar.gz", hash = "sha256:0123cacc1627ae19ddf3c27a5de5bd67ee4586fbdd6440d9748f8abb483d3e86"}, {file = "python-dateutil-2.9.0.post0.tar.gz", hash = "sha256:37dd54208da7e1cd875388217d5e00ebd4179249f90fb72437e91a35459a0ad3"},
{file = "python_dateutil-2.8.2-py2.py3-none-any.whl", hash = "sha256:961d03dc3453ebbc59dbdea9e4e11c5651520a876d0f4db161e8674aae935da9"}, {file = "python_dateutil-2.9.0.post0-py2.py3-none-any.whl", hash = "sha256:a8b2bc7bffae282281c8140a97d3aa9c14da0b136dfe83f850eea9a5f7470427"},
] ]
[package.dependencies] [package.dependencies]
@ -1252,6 +1252,7 @@ files = [
{file = "PyYAML-6.0.1-cp311-cp311-win_amd64.whl", hash = "sha256:bf07ee2fef7014951eeb99f56f39c9bb4af143d8aa3c21b1677805985307da34"}, {file = "PyYAML-6.0.1-cp311-cp311-win_amd64.whl", hash = "sha256:bf07ee2fef7014951eeb99f56f39c9bb4af143d8aa3c21b1677805985307da34"},
{file = "PyYAML-6.0.1-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:855fb52b0dc35af121542a76b9a84f8d1cd886ea97c84703eaa6d88e37a2ad28"}, {file = "PyYAML-6.0.1-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:855fb52b0dc35af121542a76b9a84f8d1cd886ea97c84703eaa6d88e37a2ad28"},
{file = "PyYAML-6.0.1-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:40df9b996c2b73138957fe23a16a4f0ba614f4c0efce1e9406a184b6d07fa3a9"}, {file = "PyYAML-6.0.1-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:40df9b996c2b73138957fe23a16a4f0ba614f4c0efce1e9406a184b6d07fa3a9"},
{file = "PyYAML-6.0.1-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a08c6f0fe150303c1c6b71ebcd7213c2858041a7e01975da3a99aed1e7a378ef"},
{file = "PyYAML-6.0.1-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:6c22bec3fbe2524cde73d7ada88f6566758a8f7227bfbf93a408a9d86bcc12a0"}, {file = "PyYAML-6.0.1-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:6c22bec3fbe2524cde73d7ada88f6566758a8f7227bfbf93a408a9d86bcc12a0"},
{file = "PyYAML-6.0.1-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:8d4e9c88387b0f5c7d5f281e55304de64cf7f9c0021a3525bd3b1c542da3b0e4"}, {file = "PyYAML-6.0.1-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:8d4e9c88387b0f5c7d5f281e55304de64cf7f9c0021a3525bd3b1c542da3b0e4"},
{file = "PyYAML-6.0.1-cp312-cp312-win32.whl", hash = "sha256:d483d2cdf104e7c9fa60c544d92981f12ad66a457afae824d146093b8c294c54"}, {file = "PyYAML-6.0.1-cp312-cp312-win32.whl", hash = "sha256:d483d2cdf104e7c9fa60c544d92981f12ad66a457afae824d146093b8c294c54"},
@ -1357,60 +1358,60 @@ files = [
[[package]] [[package]]
name = "sqlalchemy" name = "sqlalchemy"
version = "2.0.27" version = "2.0.28"
description = "Database Abstraction Library" description = "Database Abstraction Library"
optional = false optional = false
python-versions = ">=3.7" python-versions = ">=3.7"
files = [ files = [
{file = "SQLAlchemy-2.0.27-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:d04e579e911562f1055d26dab1868d3e0bb905db3bccf664ee8ad109f035618a"}, {file = "SQLAlchemy-2.0.28-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:e0b148ab0438f72ad21cb004ce3bdaafd28465c4276af66df3b9ecd2037bf252"},
{file = "SQLAlchemy-2.0.27-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:fa67d821c1fd268a5a87922ef4940442513b4e6c377553506b9db3b83beebbd8"}, {file = "SQLAlchemy-2.0.28-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:bbda76961eb8f27e6ad3c84d1dc56d5bc61ba8f02bd20fcf3450bd421c2fcc9c"},
{file = "SQLAlchemy-2.0.27-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:6c7a596d0be71b7baa037f4ac10d5e057d276f65a9a611c46970f012752ebf2d"}, {file = "SQLAlchemy-2.0.28-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:feea693c452d85ea0015ebe3bb9cd15b6f49acc1a31c28b3c50f4db0f8fb1e71"},
{file = "SQLAlchemy-2.0.27-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:954d9735ee9c3fa74874c830d089a815b7b48df6f6b6e357a74130e478dbd951"}, {file = "SQLAlchemy-2.0.28-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:5da98815f82dce0cb31fd1e873a0cb30934971d15b74e0d78cf21f9e1b05953f"},
{file = "SQLAlchemy-2.0.27-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:5cd20f58c29bbf2680039ff9f569fa6d21453fbd2fa84dbdb4092f006424c2e6"}, {file = "SQLAlchemy-2.0.28-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:4a5adf383c73f2d49ad15ff363a8748319ff84c371eed59ffd0127355d6ea1da"},
{file = "SQLAlchemy-2.0.27-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:03f448ffb731b48323bda68bcc93152f751436ad6037f18a42b7e16af9e91c07"}, {file = "SQLAlchemy-2.0.28-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:56856b871146bfead25fbcaed098269d90b744eea5cb32a952df00d542cdd368"},
{file = "SQLAlchemy-2.0.27-cp310-cp310-win32.whl", hash = "sha256:d997c5938a08b5e172c30583ba6b8aad657ed9901fc24caf3a7152eeccb2f1b4"}, {file = "SQLAlchemy-2.0.28-cp310-cp310-win32.whl", hash = "sha256:943aa74a11f5806ab68278284a4ddd282d3fb348a0e96db9b42cb81bf731acdc"},
{file = "SQLAlchemy-2.0.27-cp310-cp310-win_amd64.whl", hash = "sha256:eb15ef40b833f5b2f19eeae65d65e191f039e71790dd565c2af2a3783f72262f"}, {file = "SQLAlchemy-2.0.28-cp310-cp310-win_amd64.whl", hash = "sha256:c6c4da4843e0dabde41b8f2e8147438330924114f541949e6318358a56d1875a"},
{file = "SQLAlchemy-2.0.27-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:6c5bad7c60a392850d2f0fee8f355953abaec878c483dd7c3836e0089f046bf6"}, {file = "SQLAlchemy-2.0.28-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:46a3d4e7a472bfff2d28db838669fc437964e8af8df8ee1e4548e92710929adc"},
{file = "SQLAlchemy-2.0.27-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:a3012ab65ea42de1be81fff5fb28d6db893ef978950afc8130ba707179b4284a"}, {file = "SQLAlchemy-2.0.28-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:0d3dd67b5d69794cfe82862c002512683b3db038b99002171f624712fa71aeaa"},
{file = "SQLAlchemy-2.0.27-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:dbcd77c4d94b23e0753c5ed8deba8c69f331d4fd83f68bfc9db58bc8983f49cd"}, {file = "SQLAlchemy-2.0.28-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:c61e2e41656a673b777e2f0cbbe545323dbe0d32312f590b1bc09da1de6c2a02"},
{file = "SQLAlchemy-2.0.27-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:d177b7e82f6dd5e1aebd24d9c3297c70ce09cd1d5d37b43e53f39514379c029c"}, {file = "SQLAlchemy-2.0.28-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:0315d9125a38026227f559488fe7f7cee1bd2fbc19f9fd637739dc50bb6380b2"},
{file = "SQLAlchemy-2.0.27-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:680b9a36029b30cf063698755d277885d4a0eab70a2c7c6e71aab601323cba45"}, {file = "SQLAlchemy-2.0.28-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:af8ce2d31679006e7b747d30a89cd3ac1ec304c3d4c20973f0f4ad58e2d1c4c9"},
{file = "SQLAlchemy-2.0.27-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:1306102f6d9e625cebaca3d4c9c8f10588735ef877f0360b5cdb4fdfd3fd7131"}, {file = "SQLAlchemy-2.0.28-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:81ba314a08c7ab701e621b7ad079c0c933c58cdef88593c59b90b996e8b58fa5"},
{file = "SQLAlchemy-2.0.27-cp311-cp311-win32.whl", hash = "sha256:5b78aa9f4f68212248aaf8943d84c0ff0f74efc65a661c2fc68b82d498311fd5"}, {file = "SQLAlchemy-2.0.28-cp311-cp311-win32.whl", hash = "sha256:1ee8bd6d68578e517943f5ebff3afbd93fc65f7ef8f23becab9fa8fb315afb1d"},
{file = "SQLAlchemy-2.0.27-cp311-cp311-win_amd64.whl", hash = "sha256:15e19a84b84528f52a68143439d0c7a3a69befcd4f50b8ef9b7b69d2628ae7c4"}, {file = "SQLAlchemy-2.0.28-cp311-cp311-win_amd64.whl", hash = "sha256:ad7acbe95bac70e4e687a4dc9ae3f7a2f467aa6597049eeb6d4a662ecd990bb6"},
{file = "SQLAlchemy-2.0.27-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:0de1263aac858f288a80b2071990f02082c51d88335a1db0d589237a3435fe71"}, {file = "SQLAlchemy-2.0.28-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:d3499008ddec83127ab286c6f6ec82a34f39c9817f020f75eca96155f9765097"},
{file = "SQLAlchemy-2.0.27-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:ce850db091bf7d2a1f2fdb615220b968aeff3849007b1204bf6e3e50a57b3d32"}, {file = "SQLAlchemy-2.0.28-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:9b66fcd38659cab5d29e8de5409cdf91e9986817703e1078b2fdaad731ea66f5"},
{file = "SQLAlchemy-2.0.27-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:8dfc936870507da96aebb43e664ae3a71a7b96278382bcfe84d277b88e379b18"}, {file = "SQLAlchemy-2.0.28-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:bea30da1e76cb1acc5b72e204a920a3a7678d9d52f688f087dc08e54e2754c67"},
{file = "SQLAlchemy-2.0.27-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:c4fbe6a766301f2e8a4519f4500fe74ef0a8509a59e07a4085458f26228cd7cc"}, {file = "SQLAlchemy-2.0.28-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:124202b4e0edea7f08a4db8c81cc7859012f90a0d14ba2bf07c099aff6e96462"},
{file = "SQLAlchemy-2.0.27-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:4535c49d961fe9a77392e3a630a626af5baa967172d42732b7a43496c8b28876"}, {file = "SQLAlchemy-2.0.28-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:e23b88c69497a6322b5796c0781400692eca1ae5532821b39ce81a48c395aae9"},
{file = "SQLAlchemy-2.0.27-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:0fb3bffc0ced37e5aa4ac2416f56d6d858f46d4da70c09bb731a246e70bff4d5"}, {file = "SQLAlchemy-2.0.28-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:4b6303bfd78fb3221847723104d152e5972c22367ff66edf09120fcde5ddc2e2"},
{file = "SQLAlchemy-2.0.27-cp312-cp312-win32.whl", hash = "sha256:7f470327d06400a0aa7926b375b8e8c3c31d335e0884f509fe272b3c700a7254"}, {file = "SQLAlchemy-2.0.28-cp312-cp312-win32.whl", hash = "sha256:a921002be69ac3ab2cf0c3017c4e6a3377f800f1fca7f254c13b5f1a2f10022c"},
{file = "SQLAlchemy-2.0.27-cp312-cp312-win_amd64.whl", hash = "sha256:f9374e270e2553653d710ece397df67db9d19c60d2647bcd35bfc616f1622dcd"}, {file = "SQLAlchemy-2.0.28-cp312-cp312-win_amd64.whl", hash = "sha256:b4a2cf92995635b64876dc141af0ef089c6eea7e05898d8d8865e71a326c0385"},
{file = "SQLAlchemy-2.0.27-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:e97cf143d74a7a5a0f143aa34039b4fecf11343eed66538610debc438685db4a"}, {file = "SQLAlchemy-2.0.28-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:8e91b5e341f8c7f1e5020db8e5602f3ed045a29f8e27f7f565e0bdee3338f2c7"},
{file = "SQLAlchemy-2.0.27-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:d7b5a3e2120982b8b6bd1d5d99e3025339f7fb8b8267551c679afb39e9c7c7f1"}, {file = "SQLAlchemy-2.0.28-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:45c7b78dfc7278329f27be02c44abc0d69fe235495bb8e16ec7ef1b1a17952db"},
{file = "SQLAlchemy-2.0.27-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:e36aa62b765cf9f43a003233a8c2d7ffdeb55bc62eaa0a0380475b228663a38f"}, {file = "SQLAlchemy-2.0.28-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:3eba73ef2c30695cb7eabcdb33bb3d0b878595737479e152468f3ba97a9c22a4"},
{file = "SQLAlchemy-2.0.27-cp37-cp37m-musllinux_1_1_aarch64.whl", hash = "sha256:5ada0438f5b74c3952d916c199367c29ee4d6858edff18eab783b3978d0db16d"}, {file = "SQLAlchemy-2.0.28-cp37-cp37m-musllinux_1_1_aarch64.whl", hash = "sha256:5df5d1dafb8eee89384fb7a1f79128118bc0ba50ce0db27a40750f6f91aa99d5"},
{file = "SQLAlchemy-2.0.27-cp37-cp37m-musllinux_1_1_x86_64.whl", hash = "sha256:b1d9d1bfd96eef3c3faedb73f486c89e44e64e40e5bfec304ee163de01cf996f"}, {file = "SQLAlchemy-2.0.28-cp37-cp37m-musllinux_1_1_x86_64.whl", hash = "sha256:2858bbab1681ee5406650202950dc8f00e83b06a198741b7c656e63818633526"},
{file = "SQLAlchemy-2.0.27-cp37-cp37m-win32.whl", hash = "sha256:ca891af9f3289d24a490a5fde664ea04fe2f4984cd97e26de7442a4251bd4b7c"}, {file = "SQLAlchemy-2.0.28-cp37-cp37m-win32.whl", hash = "sha256:9461802f2e965de5cff80c5a13bc945abea7edaa1d29360b485c3d2b56cdb075"},
{file = "SQLAlchemy-2.0.27-cp37-cp37m-win_amd64.whl", hash = "sha256:fd8aafda7cdff03b905d4426b714601c0978725a19efc39f5f207b86d188ba01"}, {file = "SQLAlchemy-2.0.28-cp37-cp37m-win_amd64.whl", hash = "sha256:a6bec1c010a6d65b3ed88c863d56b9ea5eeefdf62b5e39cafd08c65f5ce5198b"},
{file = "SQLAlchemy-2.0.27-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:ec1f5a328464daf7a1e4e385e4f5652dd9b1d12405075ccba1df842f7774b4fc"}, {file = "SQLAlchemy-2.0.28-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:843a882cadebecc655a68bd9a5b8aa39b3c52f4a9a5572a3036fb1bb2ccdc197"},
{file = "SQLAlchemy-2.0.27-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:ad862295ad3f644e3c2c0d8b10a988e1600d3123ecb48702d2c0f26771f1c396"}, {file = "SQLAlchemy-2.0.28-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:dbb990612c36163c6072723523d2be7c3eb1517bbdd63fe50449f56afafd1133"},
{file = "SQLAlchemy-2.0.27-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:48217be1de7d29a5600b5c513f3f7664b21d32e596d69582be0a94e36b8309cb"}, {file = "SQLAlchemy-2.0.28-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:bd7e4baf9161d076b9a7e432fce06217b9bd90cfb8f1d543d6e8c4595627edb9"},
{file = "SQLAlchemy-2.0.27-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:9e56afce6431450442f3ab5973156289bd5ec33dd618941283847c9fd5ff06bf"}, {file = "SQLAlchemy-2.0.28-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:e0a5354cb4de9b64bccb6ea33162cb83e03dbefa0d892db88a672f5aad638a75"},
{file = "SQLAlchemy-2.0.27-cp38-cp38-musllinux_1_1_aarch64.whl", hash = "sha256:611068511b5531304137bcd7fe8117c985d1b828eb86043bd944cebb7fae3910"}, {file = "SQLAlchemy-2.0.28-cp38-cp38-musllinux_1_1_aarch64.whl", hash = "sha256:fffcc8edc508801ed2e6a4e7b0d150a62196fd28b4e16ab9f65192e8186102b6"},
{file = "SQLAlchemy-2.0.27-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:b86abba762ecfeea359112b2bb4490802b340850bbee1948f785141a5e020de8"}, {file = "SQLAlchemy-2.0.28-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:aca7b6d99a4541b2ebab4494f6c8c2f947e0df4ac859ced575238e1d6ca5716b"},
{file = "SQLAlchemy-2.0.27-cp38-cp38-win32.whl", hash = "sha256:30d81cc1192dc693d49d5671cd40cdec596b885b0ce3b72f323888ab1c3863d5"}, {file = "SQLAlchemy-2.0.28-cp38-cp38-win32.whl", hash = "sha256:8c7f10720fc34d14abad5b647bc8202202f4948498927d9f1b4df0fb1cf391b7"},
{file = "SQLAlchemy-2.0.27-cp38-cp38-win_amd64.whl", hash = "sha256:120af1e49d614d2525ac247f6123841589b029c318b9afbfc9e2b70e22e1827d"}, {file = "SQLAlchemy-2.0.28-cp38-cp38-win_amd64.whl", hash = "sha256:243feb6882b06a2af68ecf4bec8813d99452a1b62ba2be917ce6283852cf701b"},
{file = "SQLAlchemy-2.0.27-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:d07ee7793f2aeb9b80ec8ceb96bc8cc08a2aec8a1b152da1955d64e4825fcbac"}, {file = "SQLAlchemy-2.0.28-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:fc4974d3684f28b61b9a90fcb4c41fb340fd4b6a50c04365704a4da5a9603b05"},
{file = "SQLAlchemy-2.0.27-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:cb0845e934647232b6ff5150df37ceffd0b67b754b9fdbb095233deebcddbd4a"}, {file = "SQLAlchemy-2.0.28-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:87724e7ed2a936fdda2c05dbd99d395c91ea3c96f029a033a4a20e008dd876bf"},
{file = "SQLAlchemy-2.0.27-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:1fc19ae2e07a067663dd24fca55f8ed06a288384f0e6e3910420bf4b1270cc51"}, {file = "SQLAlchemy-2.0.28-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:68722e6a550f5de2e3cfe9da6afb9a7dd15ef7032afa5651b0f0c6b3adb8815d"},
{file = "SQLAlchemy-2.0.27-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:b90053be91973a6fb6020a6e44382c97739736a5a9d74e08cc29b196639eb979"}, {file = "SQLAlchemy-2.0.28-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:328529f7c7f90adcd65aed06a161851f83f475c2f664a898af574893f55d9e53"},
{file = "SQLAlchemy-2.0.27-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:2f5c9dfb0b9ab5e3a8a00249534bdd838d943ec4cfb9abe176a6c33408430230"}, {file = "SQLAlchemy-2.0.28-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:df40c16a7e8be7413b885c9bf900d402918cc848be08a59b022478804ea076b8"},
{file = "SQLAlchemy-2.0.27-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:33e8bde8fff203de50399b9039c4e14e42d4d227759155c21f8da4a47fc8053c"}, {file = "SQLAlchemy-2.0.28-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:426f2fa71331a64f5132369ede5171c52fd1df1bd9727ce621f38b5b24f48750"},
{file = "SQLAlchemy-2.0.27-cp39-cp39-win32.whl", hash = "sha256:d873c21b356bfaf1589b89090a4011e6532582b3a8ea568a00e0c3aab09399dd"}, {file = "SQLAlchemy-2.0.28-cp39-cp39-win32.whl", hash = "sha256:33157920b233bc542ce497a81a2e1452e685a11834c5763933b440fedd1d8e2d"},
{file = "SQLAlchemy-2.0.27-cp39-cp39-win_amd64.whl", hash = "sha256:ff2f1b7c963961d41403b650842dc2039175b906ab2093635d8319bef0b7d620"}, {file = "SQLAlchemy-2.0.28-cp39-cp39-win_amd64.whl", hash = "sha256:2f60843068e432311c886c5f03c4664acaef507cf716f6c60d5fde7265be9d7b"},
{file = "SQLAlchemy-2.0.27-py3-none-any.whl", hash = "sha256:1ab4e0448018d01b142c916cc7119ca573803a4745cfe341b8f95657812700ac"}, {file = "SQLAlchemy-2.0.28-py3-none-any.whl", hash = "sha256:78bb7e8da0183a8301352d569900d9d3594c48ac21dc1c2ec6b3121ed8b6c986"},
{file = "SQLAlchemy-2.0.27.tar.gz", hash = "sha256:86a6ed69a71fe6b88bf9331594fa390a2adda4a49b5c06f98e47bf0d392534f8"}, {file = "SQLAlchemy-2.0.28.tar.gz", hash = "sha256:dd53b6c4e6d960600fd6532b79ee28e2da489322fcf6648738134587faf767b6"},
] ]
[package.dependencies] [package.dependencies]

View File

@ -0,0 +1,42 @@
import os
from typing import Any, Dict, List
from elastic_transport import Transport
from elasticsearch import Elasticsearch
def clear_test_indices(es: Elasticsearch) -> None:
    """Delete every index whose name starts with ``test_``, then refresh.

    Used by integration tests to clean up indices they created without
    touching any other data on the cluster.
    """
    all_indices = es.indices.get(index="_all").keys()
    # Collect first, then delete, so we never mutate while iterating.
    test_indices = [name for name in all_indices if name.startswith("test_")]
    for name in test_indices:
        es.indices.delete(index=name)
    es.indices.refresh(index="_all")
def requests_saving_es_client() -> Elasticsearch:
    """Build an Elasticsearch client whose transport records every request.

    Each call's kwargs are appended to ``client.transport.requests`` so tests
    can inspect headers (e.g. the user-agent) after the fact. Connects to
    Elastic Cloud when ``ES_CLOUD_ID`` is set, otherwise to a local instance
    at ``ES_URL`` (default ``http://localhost:9200``).
    """

    class RequestSavingTransport(Transport):
        # Keeps a per-instance log of the kwargs of every performed request.
        def __init__(self, *args: Any, **kwargs: Any) -> None:
            super().__init__(*args, **kwargs)
            self.requests: List[Dict] = []

        def perform_request(self, *args, **kwargs):  # type: ignore
            self.requests.append(kwargs)
            return super().perform_request(*args, **kwargs)

    cloud_id = os.environ.get("ES_CLOUD_ID")
    if cloud_id:
        # Running this integration test with Elastic Cloud
        # Required for in-stack inference testing (ELSER + model_id)
        return Elasticsearch(
            cloud_id=cloud_id,
            api_key=os.environ.get("ES_API_KEY"),
            transport_class=RequestSavingTransport,
        )

    # Running this integration test with local docker instance
    es_url = os.environ.get("ES_URL", "http://localhost:9200")
    return Elasticsearch(hosts=[es_url], transport_class=RequestSavingTransport)

View File

@ -0,0 +1,35 @@
version: "3"
services:
elasticsearch:
image: docker.elastic.co/elasticsearch/elasticsearch:8.12.1 # https://www.docker.elastic.co/r/elasticsearch/elasticsearch
environment:
- discovery.type=single-node
- xpack.security.enabled=false # security has been disabled, so no login or password is required.
- xpack.security.http.ssl.enabled=false
- xpack.license.self_generated.type=trial
ports:
- "9200:9200"
healthcheck:
test:
[
"CMD-SHELL",
"curl --silent --fail http://localhost:9200/_cluster/health || exit 1"
]
interval: 10s
retries: 60
kibana:
image: docker.elastic.co/kibana/kibana:8.12.1
environment:
- ELASTICSEARCH_URL=http://elasticsearch:9200
ports:
- "5601:5601"
healthcheck:
test:
[
"CMD-SHELL",
"curl --silent --fail http://localhost:5601/login || exit 1"
]
interval: 10s
retries: 60

View File

@ -10,8 +10,8 @@ from langchain_core.messages import message_to_dict
from langchain_elasticsearch.chat_history import ElasticsearchChatMessageHistory from langchain_elasticsearch.chat_history import ElasticsearchChatMessageHistory
""" """
cd tests/integration_tests/memory/docker-compose cd tests/integration_tests
docker-compose -f elasticsearch.yml up docker-compose up elasticsearch
By default runs against local docker instance of Elasticsearch. By default runs against local docker instance of Elasticsearch.
To run against Elastic Cloud, set the following environment variables: To run against Elastic Cloud, set the following environment variables:

View File

@ -0,0 +1,169 @@
"""Test ElasticsearchRetriever functionality."""
import re
import uuid
from typing import Any, Dict
import pytest
from elasticsearch import Elasticsearch
from langchain_core.documents import Document
from langchain_elasticsearch.retrievers import ElasticsearchRetriever
from ._test_utilities import requests_saving_es_client
"""
cd tests/integration_tests
docker-compose up elasticsearch
By default runs against local docker instance of Elasticsearch.
To run against Elastic Cloud, set the following environment variables:
- ES_CLOUD_ID
- ES_API_KEY
"""
def index_test_data(es_client: Elasticsearch, index_name: str, field_name: str) -> None:
    """Index five small known documents into ``index_name``.

    Each document stores its text under ``field_name`` plus a constant
    ``another_field``; ids are the stringified integers 1-5. ``refresh=True``
    makes the documents searchable immediately.
    """
    sample_docs = {
        1: "foo bar",
        2: "bar",
        3: "foo",
        4: "baz",
        5: "foo baz",
    }
    for doc_id, text in sample_docs.items():
        es_client.index(
            index=index_name,
            id=str(doc_id),
            document={field_name: text, "another_field": 1},
            refresh=True,
        )
class TestElasticsearchRetriever:
    """Integration tests for ElasticsearchRetriever against a live cluster.

    Requires a running Elasticsearch instance (local docker or Elastic Cloud,
    see the module docstring). Each test gets a fresh, uniquely named index.
    """

    @pytest.fixture(scope="function")
    def es_client(self) -> Any:
        # Client whose transport records every request, so tests can inspect
        # the headers that were actually sent over the wire.
        return requests_saving_es_client()

    @pytest.fixture(scope="function")
    def index_name(self) -> str:
        """Return a unique index name so tests do not interfere with each other."""
        return f"test_{uuid.uuid4().hex}"

    def test_user_agent_header(self, es_client: Elasticsearch, index_name: str) -> None:
        """Test that the user agent header is set correctly."""
        retriever = ElasticsearchRetriever(
            index_name=index_name,
            body_func=lambda _: {"query": {"match_all": {}}},
            content_field="text",
            es_client=es_client,
        )
        assert retriever.es_client
        # Check the header configured on the client object itself ...
        user_agent = retriever.es_client._headers["User-Agent"]
        assert (
            re.match(r"^langchain-py-r/\d+\.\d+\.\d+$", user_agent) is not None
        ), f"The string '{user_agent}' does not match the expected pattern."
        index_test_data(es_client, index_name, "text")
        retriever.get_relevant_documents("foo")
        # ... and the header of the search request that actually went out.
        search_request = es_client.transport.requests[-1]  # type: ignore[attr-defined]
        user_agent = search_request["headers"]["User-Agent"]
        assert (
            re.match(r"^langchain-py-r/\d+\.\d+\.\d+$", user_agent) is not None
        ), f"The string '{user_agent}' does not match the expected pattern."

    def test_init_url(self, index_name: str) -> None:
        """Test end-to-end indexing and search, constructing via from_es_params."""
        text_field = "text"

        def body_func(query: str) -> Dict:
            return {"query": {"match": {text_field: {"query": query}}}}

        retriever = ElasticsearchRetriever.from_es_params(
            url="http://localhost:9200",
            index_name=index_name,
            body_func=body_func,
            content_field=text_field,
        )
        index_test_data(retriever.es_client, index_name, text_field)
        result = retriever.get_relevant_documents("foo")
        assert {r.page_content for r in result} == {"foo", "foo bar", "foo baz"}
        assert {r.metadata["_id"] for r in result} == {"3", "1", "5"}
        for r in result:
            assert set(r.metadata.keys()) == {"_index", "_id", "_score", "_source"}
            # The content field is lifted into page_content and removed from
            # _source; other fields remain available in the metadata.
            assert text_field not in r.metadata["_source"]
            assert "another_field" in r.metadata["_source"]

    def test_init_client(self, es_client: Elasticsearch, index_name: str) -> None:
        """Test end-to-end indexing and search, constructing with a client object."""
        text_field = "text"

        def body_func(query: str) -> Dict:
            return {"query": {"match": {text_field: {"query": query}}}}

        retriever = ElasticsearchRetriever(
            index_name=index_name,
            body_func=body_func,
            content_field=text_field,
            es_client=es_client,
        )
        index_test_data(es_client, index_name, text_field)
        result = retriever.get_relevant_documents("foo")
        assert {r.page_content for r in result} == {"foo", "foo bar", "foo baz"}
        assert {r.metadata["_id"] for r in result} == {"3", "1", "5"}
        for r in result:
            assert set(r.metadata.keys()) == {"_index", "_id", "_score", "_source"}
            assert text_field not in r.metadata["_source"]
            assert "another_field" in r.metadata["_source"]

    def test_custom_mapper(self, es_client: Elasticsearch, index_name: str) -> None:
        """Test custom document mapper."""
        text_field = "text"
        meta = {"some_field": 12}

        def body_func(query: str) -> Dict:
            return {"query": {"match": {text_field: {"query": query}}}}

        def id_as_content(hit: Dict) -> Document:
            # Custom mapper: use the hit id as page_content, fixed metadata.
            return Document(page_content=hit["_id"], metadata=meta)

        retriever = ElasticsearchRetriever(
            index_name=index_name,
            body_func=body_func,
            document_mapper=id_as_content,
            es_client=es_client,
        )
        index_test_data(es_client, index_name, text_field)
        result = retriever.get_relevant_documents("foo")
        assert [r.page_content for r in result] == ["3", "1", "5"]
        assert [r.metadata for r in result] == [meta, meta, meta]

    def test_fail_content_field_and_mapper(self, es_client: Elasticsearch) -> None:
        """Raise exception if both content_field and document_mapper are specified."""
        with pytest.raises(ValueError):
            ElasticsearchRetriever(
                content_field="text",
                document_mapper=lambda x: x,
                index_name="foo",
                body_func=lambda x: x,
                es_client=es_client,
            )

    def test_fail_neither_content_field_nor_mapper(
        self, es_client: Elasticsearch
    ) -> None:
        """Raise exception if neither content_field nor document_mapper are specified"""
        with pytest.raises(ValueError):
            ElasticsearchRetriever(
                index_name="foo",
                body_func=lambda x: x,
                es_client=es_client,
            )

View File

@ -7,7 +7,6 @@ import uuid
from typing import Any, Dict, Generator, List, Union from typing import Any, Dict, Generator, List, Union
import pytest import pytest
from elastic_transport import Transport
from elasticsearch import Elasticsearch from elasticsearch import Elasticsearch
from elasticsearch.helpers import BulkIndexError from elasticsearch.helpers import BulkIndexError
from langchain_core.documents import Document from langchain_core.documents import Document
@ -18,12 +17,13 @@ from ..fake_embeddings import (
ConsistentFakeEmbeddings, ConsistentFakeEmbeddings,
FakeEmbeddings, FakeEmbeddings,
) )
from ._test_utilities import clear_test_indices, requests_saving_es_client
logging.basicConfig(level=logging.DEBUG) logging.basicConfig(level=logging.DEBUG)
""" """
cd tests/integration_tests/vectorstores/docker-compose cd tests/integration_tests
docker-compose -f elasticsearch.yml up docker-compose up elasticsearch
By default runs against local docker instance of Elasticsearch. By default runs against local docker instance of Elasticsearch.
To run against Elastic Cloud, set the following environment variables: To run against Elastic Cloud, set the following environment variables:
@ -74,12 +74,8 @@ class TestElasticsearch:
es = Elasticsearch(hosts=es_url) es = Elasticsearch(hosts=es_url)
yield {"es_url": es_url} yield {"es_url": es_url}
# Clear all indexes # clear indices
index_names = es.indices.get(index="_all").keys() clear_test_indices(es)
for index_name in index_names:
if index_name.startswith("test_"):
es.indices.delete(index=index_name)
es.indices.refresh(index="_all")
# clear all test pipelines # clear all test pipelines
try: try:
@ -94,32 +90,11 @@ class TestElasticsearch:
except Exception: except Exception:
pass pass
return None
@pytest.fixture(scope="function") @pytest.fixture(scope="function")
def es_client(self) -> Any: def es_client(self) -> Any:
class CustomTransport(Transport): return requests_saving_es_client()
requests = []
def perform_request(self, *args, **kwargs): # type: ignore
self.requests.append(kwargs)
return super().perform_request(*args, **kwargs)
es_url = os.environ.get("ES_URL", "http://localhost:9200")
cloud_id = os.environ.get("ES_CLOUD_ID")
api_key = os.environ.get("ES_API_KEY")
if cloud_id:
# Running this integration test with Elastic Cloud
# Required for in-stack inference testing (ELSER + model_id)
es = Elasticsearch(
cloud_id=cloud_id,
api_key=api_key,
transport_class=CustomTransport,
)
return es
else:
# Running this integration test with local docker instance
es = Elasticsearch(hosts=es_url, transport_class=CustomTransport)
return es
@pytest.fixture(scope="function") @pytest.fixture(scope="function")
def index_name(self) -> str: def index_name(self) -> str:
@ -887,11 +862,8 @@ class TestElasticsearch:
) )
user_agent = es_client.transport.requests[0]["headers"]["User-Agent"] user_agent = es_client.transport.requests[0]["headers"]["User-Agent"]
pattern = r"^langchain-py-vs/\d+\.\d+\.\d+$"
match = re.match(pattern, user_agent)
assert ( assert (
match is not None re.match(r"^langchain-py-vs/\d+\.\d+\.\d+$", user_agent) is not None
), f"The string '{user_agent}' does not match the expected pattern." ), f"The string '{user_agent}' does not match the expected pattern."
def test_elasticsearch_with_internal_user_agent( def test_elasticsearch_with_internal_user_agent(
@ -908,15 +880,12 @@ class TestElasticsearch:
) )
user_agent = store.client._headers["User-Agent"] user_agent = store.client._headers["User-Agent"]
pattern = r"^langchain-py-vs/\d+\.\d+\.\d+$"
match = re.match(pattern, user_agent)
assert ( assert (
match is not None re.match(r"^langchain-py-vs/\d+\.\d+\.\d+$", user_agent) is not None
), f"The string '{user_agent}' does not match the expected pattern." ), f"The string '{user_agent}' does not match the expected pattern."
def test_bulk_args(self, es_client: Any, index_name: str) -> None: def test_bulk_args(self, es_client: Any, index_name: str) -> None:
"""Test to make sure the user-agent is set correctly.""" """Test to make sure the bulk arguments work as expected."""
texts = ["foo", "bob", "baz"] texts = ["foo", "bob", "baz"]
ElasticsearchStore.from_texts( ElasticsearchStore.from_texts(