Mirror of https://github.com/hwchase17/langchain.git (synced 2025-07-19 19:11:33 +00:00)
partners: add Elasticsearch package (#17467)
### Description

This PR moves the Elasticsearch classes to a partners package. Note that we will not move (and will later remove) `ElasticKnnSearch`; it was previously deprecated. `ElasticVectorSearch` is going to stay in the community package since it is still used quite a lot. Also note that I left the `ElasticsearchTranslator` for self query untouched because it resides in the main `langchain` package.

### Dependencies

There will be another PR that updates the notebooks (potentially pulling them into the partners package) and templates and removes the classes from the community package; see https://github.com/langchain-ai/langchain/pull/17468

#### Open question

How do we make the transition smooth for users? Do we move the import aliases and require people to install `langchain-elasticsearch`? Or do we remove the import aliases from the `langchain` package altogether? What has worked well for other partner packages?

---------

Co-authored-by: Erick Friis <erick@langchain.dev>
Parent: a4896da2a0
Commit: 5ab69f907f
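For readers skimming the diff below, the user-facing change boils down to a new import path. A minimal before/after sketch (assuming the new package is installed with `pip install langchain-elasticsearch`):

```python
# Before: Elasticsearch classes lived in the community package
# from langchain_community.vectorstores import ElasticsearchStore

# After this PR: dedicated partner package
from langchain_elasticsearch import ElasticsearchStore
```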
.github/workflows/_integration_test.yml (vendored): 3 changes
@@ -70,6 +70,9 @@ jobs:
ASTRA_DB_API_ENDPOINT: ${{ secrets.ASTRA_DB_API_ENDPOINT }}
ASTRA_DB_APPLICATION_TOKEN: ${{ secrets.ASTRA_DB_APPLICATION_TOKEN }}
ASTRA_DB_KEYSPACE: ${{ secrets.ASTRA_DB_KEYSPACE }}
ES_URL: ${{ secrets.ES_URL }}
ES_CLOUD_ID: ${{ secrets.ES_CLOUD_ID }}
ES_API_KEY: ${{ secrets.ES_API_KEY }}
run: |
  make integration_tests
.github/workflows/_release.yml (vendored): 3 changes
@@ -191,6 +191,9 @@ jobs:
ASTRA_DB_API_ENDPOINT: ${{ secrets.ASTRA_DB_API_ENDPOINT }}
ASTRA_DB_APPLICATION_TOKEN: ${{ secrets.ASTRA_DB_APPLICATION_TOKEN }}
ASTRA_DB_KEYSPACE: ${{ secrets.ASTRA_DB_KEYSPACE }}
ES_URL: ${{ secrets.ES_URL }}
ES_CLOUD_ID: ${{ secrets.ES_CLOUD_ID }}
ES_API_KEY: ${{ secrets.ES_API_KEY }}
run: make integration_tests
working-directory: ${{ inputs.working-directory }}
@@ -1083,7 +1083,7 @@
"metadata": {},
"outputs": [],
"source": [
- "from langchain_community.vectorstores import ElasticsearchStore\n",
+ "from langchain_elasticsearch import ElasticsearchStore\n",
"from langchain_openai import OpenAIEmbeddings\n",
"\n",
"embeddings = OpenAIEmbeddings()"
@@ -23,7 +23,7 @@ Elastic Cloud is a managed Elasticsearch service. Signup for a [free trial](http
### Install Client

```bash
- pip install elasticsearch
+ pip install langchain-elasticsearch
```

## Vector Store
@@ -31,7 +31,7 @@ pip install elasticsearch
The vector store is a simple wrapper around Elasticsearch. It provides a simple interface to store and retrieve vectors.

```python
- from langchain_community.vectorstores import ElasticsearchStore
+ from langchain_elasticsearch import ElasticsearchStore

from langchain_community.document_loaders import TextLoader
from langchain.text_splitter import CharacterTextSplitter
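Since the provider page only shows the import change, here is a hedged end-to-end sketch of the new package in use (it assumes a local Elasticsearch at http://localhost:9200 and an OpenAI API key in the environment; the index name is arbitrary):

```python
from langchain_elasticsearch import ElasticsearchStore
from langchain_openai import OpenAIEmbeddings

# Index a few texts, then run a vector similarity search against them.
store = ElasticsearchStore.from_texts(
    ["foo", "bar", "baz"],
    OpenAIEmbeddings(),
    es_url="http://localhost:9200",
    index_name="langchain-demo",
)
docs = store.similarity_search("foo", k=1)
print(docs[0].page_content)
```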
@@ -60,8 +60,8 @@
"import getpass\n",
"import os\n",
"\n",
- "from langchain_community.vectorstores import ElasticsearchStore\n",
"from langchain_core.documents import Document\n",
+ "from langchain_elasticsearch import ElasticsearchStore\n",
"from langchain_openai import OpenAIEmbeddings\n",
"\n",
"os.environ[\"OPENAI_API_KEY\"] = getpass.getpass(\"OpenAI API Key:\")\n",
@@ -24,7 +24,7 @@
},
"outputs": [],
"source": [
- "!pip -q install elasticsearch langchain"
+ "!pip -q install langchain-elasticsearch"
]
},
{
@@ -36,7 +36,7 @@
},
"outputs": [],
"source": [
- "from langchain_community.embeddings.elasticsearch import ElasticsearchEmbeddings"
+ "from langchain_elasticsearch import ElasticsearchEmbeddings"
]
},
{
@@ -21,7 +21,7 @@
"metadata": {},
"outputs": [],
"source": [
- "%pip install --upgrade --quiet elasticsearch langchain-openai tiktoken langchain"
+ "%pip install --upgrade --quiet langchain-elasticsearch langchain-openai tiktoken langchain"
]
},
{
@@ -64,7 +64,7 @@
"\n",
"Example:\n",
"```python\n",
- " from langchain_community.vectorstores.elasticsearch import ElasticsearchStore\n",
+ " from langchain_elasticsearch import ElasticsearchStore\n",
" from langchain_openai import OpenAIEmbeddings\n",
"\n",
" embedding = OpenAIEmbeddings()\n",
@@ -79,7 +79,7 @@
"\n",
"Example:\n",
"```python\n",
- " from langchain_community.vectorstores import ElasticsearchStore\n",
+ " from langchain_elasticsearch import ElasticsearchStore\n",
" from langchain_openai import OpenAIEmbeddings\n",
"\n",
" embedding = OpenAIEmbeddings()\n",
@@ -97,7 +97,7 @@
"Example:\n",
"```python\n",
" import elasticsearch\n",
- " from langchain_community.vectorstores import ElasticsearchStore\n",
+ " from langchain_elasticsearch import ElasticsearchStore\n",
"\n",
" es_client= elasticsearch.Elasticsearch(\n",
" hosts=[\"http://localhost:9200\"],\n",
@@ -137,7 +137,7 @@
"\n",
"Example:\n",
"```python\n",
- " from langchain_community.vectorstores.elasticsearch import ElasticsearchStore\n",
+ " from langchain_elasticsearch import ElasticsearchStore\n",
" from langchain_openai import OpenAIEmbeddings\n",
"\n",
" embedding = OpenAIEmbeddings()\n",
@@ -202,7 +202,7 @@
},
"outputs": [],
"source": [
- "from langchain_community.vectorstores import ElasticsearchStore\n",
+ "from langchain_elasticsearch import ElasticsearchStore\n",
"from langchain_openai import OpenAIEmbeddings"
]
},
@@ -817,7 +817,7 @@
"source": [
"from typing import Dict\n",
"\n",
- "from langchain.docstore.document import Document\n",
+ "from langchain_core.documents import Document\n",
"\n",
"\n",
"def custom_document_builder(hit: Dict) -> Document:\n",
@@ -902,7 +902,7 @@
"\n",
"```python\n",
"\n",
- "from langchain_community.vectorstores.elasticsearch import ElasticsearchStore\n",
+ "from langchain_elasticsearch import ElasticsearchStore\n",
"\n",
"db = ElasticsearchStore(\n",
" es_url=\"http://localhost:9200\",\n",
@@ -936,7 +936,7 @@
"\n",
"```python\n",
"\n",
- "from langchain_community.vectorstores.elasticsearch import ElasticsearchStore\n",
+ "from langchain_elasticsearch import ElasticsearchStore\n",
"\n",
"db = ElasticsearchStore(\n",
" es_url=\"http://localhost:9200\",\n",
@@ -91,8 +91,8 @@
"outputs": [],
"source": [
"from langchain.indexes import SQLRecordManager, index\n",
- "from langchain_community.vectorstores import ElasticsearchStore\n",
"from langchain_core.documents import Document\n",
+ "from langchain_elasticsearch import ElasticsearchStore\n",
"from langchain_openai import OpenAIEmbeddings"
]
},
libs/partners/elasticsearch/.gitignore (vendored, new file): 1 line
@@ -0,0 +1 @@
__pycache__
libs/partners/elasticsearch/LICENSE (new file): 21 lines
@@ -0,0 +1,21 @@
MIT License

Copyright (c) 2024 LangChain, Inc.

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
libs/partners/elasticsearch/Makefile (new file): 60 lines
@@ -0,0 +1,60 @@
|
||||
.PHONY: all format lint test tests integration_tests docker_tests help extended_tests
|
||||
|
||||
# Default target executed when no arguments are given to make.
|
||||
all: help
|
||||
|
||||
install:
|
||||
poetry install
|
||||
|
||||
# Define a variable for the test file path.
|
||||
TEST_FILE ?= tests/unit_tests/
|
||||
integration_test integration_tests: TEST_FILE=tests/integration_tests/
|
||||
|
||||
test tests integration_test integration_tests:
|
||||
poetry run pytest $(TEST_FILE)
|
||||
|
||||
|
||||
######################
|
||||
# LINTING AND FORMATTING
|
||||
######################
|
||||
|
||||
# Define a variable for Python and notebook files.
|
||||
PYTHON_FILES=.
|
||||
MYPY_CACHE=.mypy_cache
|
||||
lint format: PYTHON_FILES=.
|
||||
lint_diff format_diff: PYTHON_FILES=$(shell git diff --relative=libs/partners/elasticsearch --name-only --diff-filter=d master | grep -E '\.py$$|\.ipynb$$')
|
||||
lint_package: PYTHON_FILES=langchain_elasticsearch
|
||||
lint_tests: PYTHON_FILES=tests
|
||||
lint_tests: MYPY_CACHE=.mypy_cache_test
|
||||
|
||||
lint lint_diff lint_package lint_tests:
|
||||
poetry run ruff .
|
||||
poetry run ruff format $(PYTHON_FILES) --diff
|
||||
poetry run ruff --select I $(PYTHON_FILES)
|
||||
mkdir $(MYPY_CACHE); poetry run mypy $(PYTHON_FILES) --cache-dir $(MYPY_CACHE)
|
||||
|
||||
format format_diff:
|
||||
poetry run ruff format $(PYTHON_FILES)
|
||||
poetry run ruff --select I --fix $(PYTHON_FILES)
|
||||
|
||||
spell_check:
|
||||
poetry run codespell --toml pyproject.toml
|
||||
|
||||
spell_fix:
|
||||
poetry run codespell --toml pyproject.toml -w
|
||||
|
||||
check_imports: $(shell find langchain_elasticsearch -name '*.py')
|
||||
poetry run python ./scripts/check_imports.py $^
|
||||
|
||||
######################
|
||||
# HELP
|
||||
######################
|
||||
|
||||
help:
|
||||
@echo '----'
|
||||
@echo 'check_imports - check imports'
|
||||
@echo 'format - run code formatters'
|
||||
@echo 'lint - run linters'
|
||||
@echo 'test - run unit tests'
|
||||
@echo 'tests - run unit tests'
|
||||
@echo 'test TEST_FILE=<test_file> - run all tests in file'
|
libs/partners/elasticsearch/README.md (new file): 29 lines
@@ -0,0 +1,29 @@
# langchain-elasticsearch

This package contains the LangChain integration with Elasticsearch.

## Installation

```bash
pip install -U langchain-elasticsearch
```

TODO document how to get id and key

## Usage

The `ElasticsearchStore` class exposes the connection to the Elasticsearch vector store.

```python
from langchain_elasticsearch import ElasticsearchStore

embeddings = ...  # use a LangChain Embeddings class

vectorstore = ElasticsearchStore(
    es_cloud_id="your-cloud-id",
    es_api_key="your-api-key",
    index_name="your-index-name",
    embedding=embeddings,
)
```
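As a hedged follow-up to the README example: once constructed, the store behaves like any other LangChain vector store, so it can for instance be wrapped as a retriever (sketch; assumes documents were already indexed into `vectorstore` above):

```python
retriever = vectorstore.as_retriever(search_kwargs={"k": 4})
docs = retriever.invoke("What is Elasticsearch?")
```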
@@ -0,0 +1,17 @@
from langchain_elasticsearch.chat_history import ElasticsearchChatMessageHistory
from langchain_elasticsearch.embeddings import ElasticsearchEmbeddings
from langchain_elasticsearch.vectorstores import (
    ApproxRetrievalStrategy,
    ElasticsearchStore,
    ExactRetrievalStrategy,
    SparseRetrievalStrategy,
)

__all__ = [
    "ApproxRetrievalStrategy",
    "ElasticsearchChatMessageHistory",
    "ElasticsearchEmbeddings",
    "ElasticsearchStore",
    "ExactRetrievalStrategy",
    "SparseRetrievalStrategy",
]
@ -0,0 +1,82 @@
|
||||
from enum import Enum
|
||||
from typing import List, Union
|
||||
|
||||
import numpy as np
|
||||
|
||||
Matrix = Union[List[List[float]], List[np.ndarray], np.ndarray]
|
||||
|
||||
|
||||
class DistanceStrategy(str, Enum):
|
||||
"""Enumerator of the Distance strategies for calculating distances
|
||||
between vectors."""
|
||||
|
||||
EUCLIDEAN_DISTANCE = "EUCLIDEAN_DISTANCE"
|
||||
MAX_INNER_PRODUCT = "MAX_INNER_PRODUCT"
|
||||
DOT_PRODUCT = "DOT_PRODUCT"
|
||||
JACCARD = "JACCARD"
|
||||
COSINE = "COSINE"
|
||||
|
||||
|
||||
def maximal_marginal_relevance(
|
||||
query_embedding: np.ndarray,
|
||||
embedding_list: list,
|
||||
lambda_mult: float = 0.5,
|
||||
k: int = 4,
|
||||
) -> List[int]:
|
||||
"""Calculate maximal marginal relevance."""
|
||||
if min(k, len(embedding_list)) <= 0:
|
||||
return []
|
||||
if query_embedding.ndim == 1:
|
||||
query_embedding = np.expand_dims(query_embedding, axis=0)
|
||||
similarity_to_query = cosine_similarity(query_embedding, embedding_list)[0]
|
||||
most_similar = int(np.argmax(similarity_to_query))
|
||||
idxs = [most_similar]
|
||||
selected = np.array([embedding_list[most_similar]])
|
||||
while len(idxs) < min(k, len(embedding_list)):
|
||||
best_score = -np.inf
|
||||
idx_to_add = -1
|
||||
similarity_to_selected = cosine_similarity(embedding_list, selected)
|
||||
for i, query_score in enumerate(similarity_to_query):
|
||||
if i in idxs:
|
||||
continue
|
||||
redundant_score = max(similarity_to_selected[i])
|
||||
equation_score = (
|
||||
lambda_mult * query_score - (1 - lambda_mult) * redundant_score
|
||||
)
|
||||
if equation_score > best_score:
|
||||
best_score = equation_score
|
||||
idx_to_add = i
|
||||
idxs.append(idx_to_add)
|
||||
selected = np.append(selected, [embedding_list[idx_to_add]], axis=0)
|
||||
return idxs
|
||||
|
||||
|
||||
def cosine_similarity(X: Matrix, Y: Matrix) -> np.ndarray:
|
||||
"""Row-wise cosine similarity between two equal-width matrices."""
|
||||
if len(X) == 0 or len(Y) == 0:
|
||||
return np.array([])
|
||||
|
||||
X = np.array(X)
|
||||
Y = np.array(Y)
|
||||
if X.shape[1] != Y.shape[1]:
|
||||
raise ValueError(
|
||||
f"Number of columns in X and Y must be the same. X has shape {X.shape} "
|
||||
f"and Y has shape {Y.shape}."
|
||||
)
|
||||
try:
|
||||
import simsimd as simd # type: ignore
|
||||
|
||||
X = np.array(X, dtype=np.float32)
|
||||
Y = np.array(Y, dtype=np.float32)
|
||||
Z = 1 - simd.cdist(X, Y, metric="cosine")
|
||||
if isinstance(Z, float):
|
||||
return np.array([Z])
|
||||
return Z
|
||||
except ImportError:
|
||||
X_norm = np.linalg.norm(X, axis=1)
|
||||
Y_norm = np.linalg.norm(Y, axis=1)
|
||||
# Ignore divide by zero errors run time warnings as those are handled below.
|
||||
with np.errstate(divide="ignore", invalid="ignore"):
|
||||
similarity = np.dot(X, Y.T) / np.outer(X_norm, Y_norm)
|
||||
similarity[np.isnan(similarity) | np.isinf(similarity)] = 0.0
|
||||
return similarity
|
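The two helpers above are plain NumPy and can be exercised on their own. A small, self-contained sketch (the vectors are made up; it uses the `maximal_marginal_relevance` function defined above):

```python
import numpy as np

query = np.array([1.0, 0.0, 0.0])
candidates = [
    np.array([1.0, 0.0, 0.0]),  # identical to the query
    np.array([0.9, 0.1, 0.0]),  # near-duplicate of the first candidate
    np.array([0.0, 1.0, 0.0]),  # orthogonal to the query
    np.array([0.0, 0.0, 1.0]),  # orthogonal to the query
]

# With a low lambda_mult the second pick favors diversity over the near-duplicate.
chosen = maximal_marginal_relevance(query, candidates, lambda_mult=0.1, k=2)
print(chosen)  # [0, 2]
```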
@ -0,0 +1,201 @@
|
||||
import json
|
||||
import logging
|
||||
from time import time
|
||||
from typing import TYPE_CHECKING, Any, Dict, List, Optional
|
||||
|
||||
from langchain_core.chat_history import BaseChatMessageHistory
|
||||
from langchain_core.messages import (
|
||||
BaseMessage,
|
||||
message_to_dict,
|
||||
messages_from_dict,
|
||||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from elasticsearch import Elasticsearch
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class ElasticsearchChatMessageHistory(BaseChatMessageHistory):
|
||||
"""Chat message history that stores history in Elasticsearch.
|
||||
|
||||
Args:
|
||||
es_url: URL of the Elasticsearch instance to connect to.
|
||||
es_cloud_id: Cloud ID of the Elasticsearch instance to connect to.
|
||||
es_user: Username to use when connecting to Elasticsearch.
|
||||
es_password: Password to use when connecting to Elasticsearch.
|
||||
es_api_key: API key to use when connecting to Elasticsearch.
|
||||
es_connection: Optional pre-existing Elasticsearch connection.
|
||||
esnsure_ascii: Used to escape ASCII symbols in json.dumps. Defaults to True.
|
||||
index: Name of the index to use.
|
||||
session_id: Arbitrary key that is used to store the messages
|
||||
of a single chat session.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
index: str,
|
||||
session_id: str,
|
||||
*,
|
||||
es_connection: Optional["Elasticsearch"] = None,
|
||||
es_url: Optional[str] = None,
|
||||
es_cloud_id: Optional[str] = None,
|
||||
es_user: Optional[str] = None,
|
||||
es_api_key: Optional[str] = None,
|
||||
es_password: Optional[str] = None,
|
||||
esnsure_ascii: Optional[bool] = True,
|
||||
):
|
||||
self.index: str = index
|
||||
self.session_id: str = session_id
|
||||
self.ensure_ascii = esnsure_ascii
|
||||
|
||||
# Initialize Elasticsearch client from passed client arg or connection info
|
||||
if es_connection is not None:
|
||||
self.client = es_connection.options(
|
||||
headers={"user-agent": self.get_user_agent()}
|
||||
)
|
||||
elif es_url is not None or es_cloud_id is not None:
|
||||
self.client = ElasticsearchChatMessageHistory.connect_to_elasticsearch(
|
||||
es_url=es_url,
|
||||
username=es_user,
|
||||
password=es_password,
|
||||
cloud_id=es_cloud_id,
|
||||
api_key=es_api_key,
|
||||
)
|
||||
else:
|
||||
raise ValueError(
|
||||
"""Either provide a pre-existing Elasticsearch connection, \
|
||||
or valid credentials for creating a new connection."""
|
||||
)
|
||||
|
||||
if self.client.indices.exists(index=index):
|
||||
logger.debug(
|
||||
f"Chat history index {index} already exists, skipping creation."
|
||||
)
|
||||
else:
|
||||
logger.debug(f"Creating index {index} for storing chat history.")
|
||||
|
||||
self.client.indices.create(
|
||||
index=index,
|
||||
mappings={
|
||||
"properties": {
|
||||
"session_id": {"type": "keyword"},
|
||||
"created_at": {"type": "date"},
|
||||
"history": {"type": "text"},
|
||||
}
|
||||
},
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def get_user_agent() -> str:
|
||||
from langchain_core import __version__
|
||||
|
||||
return f"langchain-py-ms/{__version__}"
|
||||
|
||||
@staticmethod
|
||||
def connect_to_elasticsearch(
|
||||
*,
|
||||
es_url: Optional[str] = None,
|
||||
cloud_id: Optional[str] = None,
|
||||
api_key: Optional[str] = None,
|
||||
username: Optional[str] = None,
|
||||
password: Optional[str] = None,
|
||||
) -> "Elasticsearch":
|
||||
try:
|
||||
import elasticsearch
|
||||
except ImportError:
|
||||
raise ImportError(
|
||||
"Could not import elasticsearch python package. "
|
||||
"Please install it with `pip install elasticsearch`."
|
||||
)
|
||||
|
||||
if es_url and cloud_id:
|
||||
raise ValueError(
|
||||
"Both es_url and cloud_id are defined. Please provide only one."
|
||||
)
|
||||
|
||||
connection_params: Dict[str, Any] = {}
|
||||
|
||||
if es_url:
|
||||
connection_params["hosts"] = [es_url]
|
||||
elif cloud_id:
|
||||
connection_params["cloud_id"] = cloud_id
|
||||
else:
|
||||
raise ValueError("Please provide either elasticsearch_url or cloud_id.")
|
||||
|
||||
if api_key:
|
||||
connection_params["api_key"] = api_key
|
||||
elif username and password:
|
||||
connection_params["basic_auth"] = (username, password)
|
||||
|
||||
es_client = elasticsearch.Elasticsearch(
|
||||
**connection_params,
|
||||
headers={"user-agent": ElasticsearchChatMessageHistory.get_user_agent()},
|
||||
)
|
||||
try:
|
||||
es_client.info()
|
||||
except Exception as err:
|
||||
logger.error(f"Error connecting to Elasticsearch: {err}")
|
||||
raise err
|
||||
|
||||
return es_client
|
||||
|
||||
@property
|
||||
def messages(self) -> List[BaseMessage]: # type: ignore[override]
|
||||
"""Retrieve the messages from Elasticsearch"""
|
||||
try:
|
||||
from elasticsearch import ApiError
|
||||
|
||||
result = self.client.search(
|
||||
index=self.index,
|
||||
query={"term": {"session_id": self.session_id}},
|
||||
sort="created_at:asc",
|
||||
)
|
||||
except ApiError as err:
|
||||
logger.error(f"Could not retrieve messages from Elasticsearch: {err}")
|
||||
raise err
|
||||
|
||||
if result and len(result["hits"]["hits"]) > 0:
|
||||
items = [
|
||||
json.loads(document["_source"]["history"])
|
||||
for document in result["hits"]["hits"]
|
||||
]
|
||||
else:
|
||||
items = []
|
||||
|
||||
return messages_from_dict(items)
|
||||
|
||||
def add_message(self, message: BaseMessage) -> None:
|
||||
"""Add a message to the chat session in Elasticsearch"""
|
||||
try:
|
||||
from elasticsearch import ApiError
|
||||
|
||||
self.client.index(
|
||||
index=self.index,
|
||||
document={
|
||||
"session_id": self.session_id,
|
||||
"created_at": round(time() * 1000),
|
||||
"history": json.dumps(
|
||||
message_to_dict(message),
|
||||
ensure_ascii=bool(self.ensure_ascii),
|
||||
),
|
||||
},
|
||||
refresh=True,
|
||||
)
|
||||
except ApiError as err:
|
||||
logger.error(f"Could not add message to Elasticsearch: {err}")
|
||||
raise err
|
||||
|
||||
def clear(self) -> None:
|
||||
"""Clear session memory in Elasticsearch"""
|
||||
try:
|
||||
from elasticsearch import ApiError
|
||||
|
||||
self.client.delete_by_query(
|
||||
index=self.index,
|
||||
query={"term": {"session_id": self.session_id}},
|
||||
refresh=True,
|
||||
)
|
||||
except ApiError as err:
|
||||
logger.error(f"Could not clear session memory in Elasticsearch: {err}")
|
||||
raise err
|
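A quick usage sketch of the class above (assumes a local, unsecured Elasticsearch at http://localhost:9200; index and session names are arbitrary):

```python
from langchain_elasticsearch import ElasticsearchChatMessageHistory

history = ElasticsearchChatMessageHistory(
    index="chat-history",
    session_id="session-1",
    es_url="http://localhost:9200",
)

history.add_user_message("hi!")                # helpers inherited from BaseChatMessageHistory
history.add_ai_message("hello, how are you?")
print(history.messages)                        # this session's messages, oldest first
history.clear()                                # delete this session's messages from the index
```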
@ -0,0 +1,208 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import TYPE_CHECKING, List, Optional
|
||||
|
||||
from elasticsearch import Elasticsearch
|
||||
from langchain_core.embeddings import Embeddings
|
||||
from langchain_core.utils import get_from_env
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from elasticsearch.client import MlClient
|
||||
|
||||
|
||||
class ElasticsearchEmbeddings(Embeddings):
|
||||
"""Elasticsearch embedding models.
|
||||
|
||||
This class provides an interface to generate embeddings using a model deployed
|
||||
in an Elasticsearch cluster. It requires an Elasticsearch connection object
|
||||
and the model_id of the model deployed in the cluster.
|
||||
|
||||
In Elasticsearch you need to have an embedding model loaded and deployed.
|
||||
- https://www.elastic.co/guide/en/elasticsearch/reference/current/infer-trained-model.html
|
||||
- https://www.elastic.co/guide/en/machine-learning/current/ml-nlp-deploy-models.html
|
||||
""" # noqa: E501
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
client: MlClient,
|
||||
model_id: str,
|
||||
*,
|
||||
input_field: str = "text_field",
|
||||
):
|
||||
"""
|
||||
Initialize the ElasticsearchEmbeddings instance.
|
||||
|
||||
Args:
|
||||
client (MlClient): An Elasticsearch ML client object.
|
||||
model_id (str): The model_id of the model deployed in the Elasticsearch
|
||||
cluster.
|
||||
input_field (str): The name of the key for the input text field in the
|
||||
document. Defaults to 'text_field'.
|
||||
"""
|
||||
self.client = client
|
||||
self.model_id = model_id
|
||||
self.input_field = input_field
|
||||
|
||||
@classmethod
|
||||
def from_credentials(
|
||||
cls,
|
||||
model_id: str,
|
||||
*,
|
||||
es_cloud_id: Optional[str] = None,
|
||||
es_api_key: Optional[str] = None,
|
||||
input_field: str = "text_field",
|
||||
) -> ElasticsearchEmbeddings:
|
||||
"""Instantiate embeddings from Elasticsearch credentials.
|
||||
|
||||
Args:
|
||||
model_id (str): The model_id of the model deployed in the Elasticsearch
|
||||
cluster.
|
||||
input_field (str): The name of the key for the input text field in the
|
||||
document. Defaults to 'text_field'.
|
||||
es_cloud_id: (str, optional): The Elasticsearch cloud ID to connect to.
|
||||
es_api_key: (str, optional): The Elasticsearch API key to use.
|
||||
|
||||
Example:
|
||||
.. code-block:: python
|
||||
|
||||
from langchain_elasticsearch.embeddings import ElasticsearchEmbeddings
|
||||
|
||||
# Define the model ID and input field name (if different from default)
|
||||
model_id = "your_model_id"
|
||||
# Optional, only if different from 'text_field'
|
||||
input_field = "your_input_field"
|
||||
|
||||
# Credentials can be passed in two ways. Either set the env vars
|
||||
# ES_CLOUD_ID, ES_API_KEY and they will be automatically
|
||||
# pulled in, or pass them in directly as kwargs.
|
||||
embeddings = ElasticsearchEmbeddings.from_credentials(
|
||||
model_id,
|
||||
input_field=input_field,
|
||||
# es_cloud_id="foo",
|
||||
# es_api_key="bar",
|
||||
)
|
||||
|
||||
documents = [
|
||||
"This is an example document.",
|
||||
"Another example document to generate embeddings for.",
|
||||
]
|
||||
embeddings.embed_documents(documents)
|
||||
"""
|
||||
from elasticsearch.client import MlClient
|
||||
|
||||
es_cloud_id = es_cloud_id or get_from_env("es_cloud_id", "ES_CLOUD_ID")
|
||||
es_api_key = es_api_key or get_from_env("es_api_key", "ES_API_KEY")
|
||||
|
||||
# Connect to Elasticsearch
|
||||
es_connection = Elasticsearch(cloud_id=es_cloud_id, api_key=es_api_key)
|
||||
client = MlClient(es_connection)
|
||||
return cls(client, model_id, input_field=input_field)
|
||||
|
||||
@classmethod
|
||||
def from_es_connection(
|
||||
cls,
|
||||
model_id: str,
|
||||
es_connection: Elasticsearch,
|
||||
input_field: str = "text_field",
|
||||
) -> ElasticsearchEmbeddings:
|
||||
"""
|
||||
Instantiate embeddings from an existing Elasticsearch connection.
|
||||
|
||||
This method provides a way to create an instance of the ElasticsearchEmbeddings
|
||||
class using an existing Elasticsearch connection. The connection object is used
|
||||
to create an MlClient, which is then used to initialize the
|
||||
ElasticsearchEmbeddings instance.
|
||||
|
||||
Args:
|
||||
model_id (str): The model_id of the model deployed in the Elasticsearch cluster.
|
||||
es_connection (elasticsearch.Elasticsearch): An existing Elasticsearch
|
||||
connection object. input_field (str, optional): The name of the key for the
|
||||
input text field in the document. Defaults to 'text_field'.
|
||||
|
||||
Returns:
|
||||
ElasticsearchEmbeddings: An instance of the ElasticsearchEmbeddings class.
|
||||
|
||||
Example:
|
||||
.. code-block:: python
|
||||
|
||||
from elasticsearch import Elasticsearch
|
||||
|
||||
from langchain_elasticsearch.embeddings import ElasticsearchEmbeddings
|
||||
|
||||
# Define the model ID and input field name (if different from default)
|
||||
model_id = "your_model_id"
|
||||
# Optional, only if different from 'text_field'
|
||||
input_field = "your_input_field"
|
||||
|
||||
# Create Elasticsearch connection
|
||||
es_connection = Elasticsearch(
|
||||
hosts=["localhost:9200"], http_auth=("user", "password")
|
||||
)
|
||||
|
||||
# Instantiate ElasticsearchEmbeddings using the existing connection
|
||||
embeddings = ElasticsearchEmbeddings.from_es_connection(
|
||||
model_id,
|
||||
es_connection,
|
||||
input_field=input_field,
|
||||
)
|
||||
|
||||
documents = [
|
||||
"This is an example document.",
|
||||
"Another example document to generate embeddings for.",
|
||||
]
|
||||
embeddings.embed_documents(documents)
|
||||
"""
|
||||
from elasticsearch.client import MlClient
|
||||
|
||||
# Create an MlClient from the given Elasticsearch connection
|
||||
client = MlClient(es_connection)
|
||||
|
||||
# Return a new instance of the ElasticsearchEmbeddings class with
|
||||
# the MlClient, model_id, and input_field
|
||||
return cls(client, model_id, input_field=input_field)
|
||||
|
||||
def _embedding_func(self, texts: List[str]) -> List[List[float]]:
|
||||
"""
|
||||
Generate embeddings for the given texts using the Elasticsearch model.
|
||||
|
||||
Args:
|
||||
texts (List[str]): A list of text strings to generate embeddings for.
|
||||
|
||||
Returns:
|
||||
List[List[float]]: A list of embeddings, one for each text in the input
|
||||
list.
|
||||
"""
|
||||
response = self.client.infer_trained_model(
|
||||
model_id=self.model_id, docs=[{self.input_field: text} for text in texts]
|
||||
)
|
||||
|
||||
embeddings = [doc["predicted_value"] for doc in response["inference_results"]]
|
||||
return embeddings
|
||||
|
||||
def embed_documents(self, texts: List[str]) -> List[List[float]]:
|
||||
"""
|
||||
Generate embeddings for a list of documents.
|
||||
|
||||
Args:
|
||||
texts (List[str]): A list of document text strings to generate embeddings
|
||||
for.
|
||||
|
||||
Returns:
|
||||
List[List[float]]: A list of embeddings, one for each document in the input
|
||||
list.
|
||||
"""
|
||||
return self._embedding_func(texts)
|
||||
|
||||
def embed_query(self, text: str) -> List[float]:
|
||||
"""
|
||||
Generate an embedding for a single query text.
|
||||
|
||||
Args:
|
||||
text (str): The query text to generate an embedding for.
|
||||
|
||||
Returns:
|
||||
List[float]: The embedding for the input query text.
|
||||
"""
|
||||
return self._embedding_func([text])[0]
|
libs/partners/elasticsearch/langchain_elasticsearch/vectorstores.py (new file): 1285 lines (file diff suppressed because it is too large)
libs/partners/elasticsearch/poetry.lock (generated, new file): 1655 lines (file diff suppressed because it is too large)
libs/partners/elasticsearch/pyproject.toml (new file): 96 lines
@@ -0,0 +1,96 @@
|
||||
[tool.poetry]
|
||||
name = "langchain-elasticsearch"
|
||||
version = "0.1.0"
|
||||
description = "An integration package connecting Elasticsearch and LangChain"
|
||||
authors = []
|
||||
readme = "README.md"
|
||||
repository = "https://github.com/langchain-ai/langchain"
|
||||
license = "MIT"
|
||||
|
||||
[tool.poetry.urls]
|
||||
"Source Code" = "https://github.com/langchain-ai/langchain/tree/master/libs/partners/elasticsearch"
|
||||
|
||||
[tool.poetry.dependencies]
|
||||
python = ">=3.8.1,<4.0"
|
||||
langchain-core = "^0.1"
|
||||
elasticsearch = "^8.12.0"
|
||||
numpy = "^1"
|
||||
|
||||
[tool.poetry.group.test]
|
||||
optional = true
|
||||
|
||||
[tool.poetry.group.test.dependencies]
|
||||
pytest = "^7.3.0"
|
||||
freezegun = "^1.2.2"
|
||||
pytest-mock = "^3.10.0"
|
||||
syrupy = "^4.0.2"
|
||||
pytest-watcher = "^0.3.4"
|
||||
pytest-asyncio = "^0.21.1"
|
||||
langchain = { path = "../../langchain", develop = true }
|
||||
langchain-community = { path = "../../community", develop = true }
|
||||
langchain-core = { path = "../../core", develop = true }
|
||||
|
||||
[tool.poetry.group.codespell]
|
||||
optional = true
|
||||
|
||||
[tool.poetry.group.codespell.dependencies]
|
||||
codespell = "^2.2.0"
|
||||
|
||||
[tool.poetry.group.lint]
|
||||
optional = true
|
||||
|
||||
[tool.poetry.group.lint.dependencies]
|
||||
ruff = "^0.1.5"
|
||||
|
||||
[tool.poetry.group.typing.dependencies]
|
||||
mypy = "^0.991"
|
||||
langchain-core = { path = "../../core", develop = true }
|
||||
|
||||
[tool.poetry.group.dev]
|
||||
optional = true
|
||||
|
||||
[tool.poetry.group.dev.dependencies]
|
||||
langchain-core = { path = "../../core", develop = true }
|
||||
|
||||
[tool.poetry.group.test_integration]
|
||||
optional = true
|
||||
|
||||
[tool.poetry.group.test_integration.dependencies]
|
||||
|
||||
|
||||
[tool.ruff]
|
||||
select = [
|
||||
"E", # pycodestyle
|
||||
"F", # pyflakes
|
||||
"I", # isort
|
||||
]
|
||||
|
||||
[tool.mypy]
|
||||
disallow_untyped_defs = "True"
|
||||
|
||||
[tool.coverage.run]
|
||||
omit = ["tests/*"]
|
||||
|
||||
[build-system]
|
||||
requires = ["poetry-core>=1.0.0"]
|
||||
build-backend = "poetry.core.masonry.api"
|
||||
|
||||
[tool.pytest.ini_options]
|
||||
# --strict-markers will raise errors on unknown marks.
|
||||
# https://docs.pytest.org/en/7.1.x/how-to/mark.html#raising-errors-on-unknown-marks
|
||||
#
|
||||
# https://docs.pytest.org/en/7.1.x/reference/reference.html
|
||||
# --strict-config any warnings encountered while parsing the `pytest`
|
||||
# section of the configuration file raise errors.
|
||||
#
|
||||
# https://github.com/tophat/syrupy
|
||||
# --snapshot-warn-unused Prints a warning on unused snapshots rather than fail the test suite.
|
||||
addopts = "--snapshot-warn-unused --strict-markers --strict-config --durations=5"
|
||||
# Registering custom markers.
|
||||
# https://docs.pytest.org/en/7.1.x/example/markers.html#registering-markers
|
||||
markers = [
|
||||
"requires: mark tests as requiring a specific library",
|
||||
"asyncio: mark tests as requiring asyncio",
|
||||
"compile: mark placeholder test used to compile integration tests without running them",
|
||||
]
|
||||
asyncio_mode = "auto"
|
libs/partners/elasticsearch/scripts/check_imports.py (new file): 17 lines
@@ -0,0 +1,17 @@
|
||||
import sys
|
||||
import traceback
|
||||
from importlib.machinery import SourceFileLoader
|
||||
|
||||
if __name__ == "__main__":
|
||||
files = sys.argv[1:]
|
||||
has_failure = False
|
||||
for file in files:
|
||||
try:
|
||||
SourceFileLoader("x", file).load_module()
|
||||
except Exception:
|
||||
has_failure = True
|
||||
print(file)
|
||||
traceback.print_exc()
|
||||
print()
|
||||
|
||||
sys.exit(1 if has_failure else 0)
|
libs/partners/elasticsearch/scripts/check_pydantic.sh (executable, new file): 27 lines
@@ -0,0 +1,27 @@
|
||||
#!/bin/bash
|
||||
#
|
||||
# This script searches for lines starting with "import pydantic" or "from pydantic"
|
||||
# in tracked files within a Git repository.
|
||||
#
|
||||
# Usage: ./scripts/check_pydantic.sh /path/to/repository
|
||||
|
||||
# Check if a path argument is provided
|
||||
if [ $# -ne 1 ]; then
|
||||
echo "Usage: $0 /path/to/repository"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
repository_path="$1"
|
||||
|
||||
# Search for lines matching the pattern within the specified repository
|
||||
result=$(git -C "$repository_path" grep -E '^import pydantic|^from pydantic')
|
||||
|
||||
# Check if any matching lines were found
|
||||
if [ -n "$result" ]; then
|
||||
echo "ERROR: The following lines need to be updated:"
|
||||
echo "$result"
|
||||
echo "Please replace the code with an import from langchain_core.pydantic_v1."
|
||||
echo "For example, replace 'from pydantic import BaseModel'"
|
||||
echo "with 'from langchain_core.pydantic_v1 import BaseModel'"
|
||||
exit 1
|
||||
fi
|
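The script's suggested fix, shown as a minimal sketch (the imported class name is just an example):

```python
# Instead of:  from pydantic import BaseModel
from langchain_core.pydantic_v1 import BaseModel
```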
libs/partners/elasticsearch/scripts/lint_imports.sh (executable, new file): 17 lines
@@ -0,0 +1,17 @@
|
||||
#!/bin/bash
|
||||
|
||||
set -eu
|
||||
|
||||
# Initialize a variable to keep track of errors
|
||||
errors=0
|
||||
|
||||
# make sure not importing from langchain or langchain_experimental
|
||||
git --no-pager grep '^from langchain\.' . && errors=$((errors+1))
|
||||
git --no-pager grep '^from langchain_experimental\.' . && errors=$((errors+1))
|
||||
|
||||
# Decide on an exit status based on the errors
|
||||
if [ "$errors" -gt 0 ]; then
|
||||
exit 1
|
||||
else
|
||||
exit 0
|
||||
fi
|
libs/partners/elasticsearch/tests/__init__.py (new file): empty
libs/partners/elasticsearch/tests/fake_embeddings.py (new file): 55 lines
@@ -0,0 +1,55 @@
|
||||
"""Fake Embedding class for testing purposes."""
|
||||
|
||||
from typing import List
|
||||
|
||||
from langchain_core.embeddings import Embeddings
|
||||
|
||||
fake_texts = ["foo", "bar", "baz"]
|
||||
|
||||
|
||||
class FakeEmbeddings(Embeddings):
|
||||
"""Fake embeddings functionality for testing."""
|
||||
|
||||
def embed_documents(self, texts: List[str]) -> List[List[float]]:
|
||||
"""Return simple embeddings.
|
||||
Embeddings encode each text as its index."""
|
||||
return [[float(1.0)] * 9 + [float(i)] for i in range(len(texts))]
|
||||
|
||||
async def aembed_documents(self, texts: List[str]) -> List[List[float]]:
|
||||
return self.embed_documents(texts)
|
||||
|
||||
def embed_query(self, text: str) -> List[float]:
|
||||
"""Return constant query embeddings.
|
||||
Embeddings are identical to embed_documents(texts)[0].
|
||||
Distance to each text will be that text's index,
|
||||
as it was passed to embed_documents."""
|
||||
return [float(1.0)] * 9 + [float(0.0)]
|
||||
|
||||
async def aembed_query(self, text: str) -> List[float]:
|
||||
return self.embed_query(text)
|
||||
|
||||
|
||||
class ConsistentFakeEmbeddings(FakeEmbeddings):
|
||||
"""Fake embeddings which remember all the texts seen so far to return consistent
|
||||
vectors for the same texts."""
|
||||
|
||||
def __init__(self, dimensionality: int = 10) -> None:
|
||||
self.known_texts: List[str] = []
|
||||
self.dimensionality = dimensionality
|
||||
|
||||
def embed_documents(self, texts: List[str]) -> List[List[float]]:
|
||||
"""Return consistent embeddings for each text seen so far."""
|
||||
out_vectors = []
|
||||
for text in texts:
|
||||
if text not in self.known_texts:
|
||||
self.known_texts.append(text)
|
||||
vector = [float(1.0)] * (self.dimensionality - 1) + [
|
||||
float(self.known_texts.index(text))
|
||||
]
|
||||
out_vectors.append(vector)
|
||||
return out_vectors
|
||||
|
||||
def embed_query(self, text: str) -> List[float]:
|
||||
"""Return consistent embeddings for the text, if seen before, or a constant
|
||||
one if the text is unknown."""
|
||||
return self.embed_documents([text])[0]
|
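To make the expectations in the integration tests below easier to follow, a short sketch of what these fakes actually return (using the classes defined above):

```python
emb = FakeEmbeddings()

doc_vectors = emb.embed_documents(["foo", "bar"])
# Each vector is nine 1.0s followed by the text's position: [..., 0.0] and [..., 1.0]

query_vector = emb.embed_query("anything")
# Always nine 1.0s followed by 0.0 - this is the query_vector the tests assert on
```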
@ -0,0 +1,89 @@
|
||||
import json
|
||||
import os
|
||||
import uuid
|
||||
from typing import Generator, Union
|
||||
|
||||
import pytest
|
||||
from langchain.memory import ConversationBufferMemory
|
||||
from langchain_core.messages import message_to_dict
|
||||
|
||||
from langchain_elasticsearch.chat_history import ElasticsearchChatMessageHistory
|
||||
|
||||
"""
|
||||
cd tests/integration_tests/memory/docker-compose
|
||||
docker-compose -f elasticsearch.yml up
|
||||
|
||||
By default runs against local docker instance of Elasticsearch.
|
||||
To run against Elastic Cloud, set the following environment variables:
|
||||
- ES_CLOUD_ID
|
||||
- ES_USERNAME
|
||||
- ES_PASSWORD
|
||||
"""
|
||||
|
||||
|
||||
class TestElasticsearch:
|
||||
@pytest.fixture(scope="class", autouse=True)
|
||||
def elasticsearch_connection(self) -> Union[dict, Generator[dict, None, None]]:
|
||||
# Run this integration test against Elasticsearch on localhost,
|
||||
# or an Elastic Cloud instance
|
||||
from elasticsearch import Elasticsearch
|
||||
|
||||
es_url = os.environ.get("ES_URL", "http://localhost:9200")
|
||||
es_cloud_id = os.environ.get("ES_CLOUD_ID")
|
||||
es_api_key = os.environ.get("ES_API_KEY")
|
||||
|
||||
if es_cloud_id:
|
||||
es = Elasticsearch(
|
||||
cloud_id=es_cloud_id,
|
||||
api_key=es_api_key,
|
||||
)
|
||||
yield {
|
||||
"es_cloud_id": es_cloud_id,
|
||||
"es_api_key": es_api_key,
|
||||
}
|
||||
|
||||
else:
|
||||
# Running this integration test with local docker instance
|
||||
es = Elasticsearch(hosts=es_url)
|
||||
yield {"es_url": es_url}
|
||||
|
||||
# Clear all indexes
|
||||
index_names = es.indices.get(index="_all").keys()
|
||||
for index_name in index_names:
|
||||
if index_name.startswith("test_"):
|
||||
es.indices.delete(index=index_name)
|
||||
es.indices.refresh(index="_all")
|
||||
|
||||
@pytest.fixture(scope="function")
|
||||
def index_name(self) -> str:
|
||||
"""Return the index name."""
|
||||
return f"test_{uuid.uuid4().hex}"
|
||||
|
||||
def test_memory_with_message_store(
|
||||
self, elasticsearch_connection: dict, index_name: str
|
||||
) -> None:
|
||||
"""Test the memory with a message store."""
|
||||
# setup Elasticsearch as a message store
|
||||
message_history = ElasticsearchChatMessageHistory(
|
||||
**elasticsearch_connection, index=index_name, session_id="test-session"
|
||||
)
|
||||
|
||||
memory = ConversationBufferMemory(
|
||||
memory_key="baz", chat_memory=message_history, return_messages=True
|
||||
)
|
||||
|
||||
# add some messages
|
||||
memory.chat_memory.add_ai_message("This is me, the AI")
|
||||
memory.chat_memory.add_user_message("This is me, the human")
|
||||
|
||||
# get the message history from the memory store and turn it into a json
|
||||
messages = memory.chat_memory.messages
|
||||
messages_json = json.dumps([message_to_dict(msg) for msg in messages])
|
||||
|
||||
assert "This is me, the AI" in messages_json
|
||||
assert "This is me, the human" in messages_json
|
||||
|
||||
# remove the record from Elasticsearch, so the next test run won't pick it up
|
||||
memory.chat_memory.clear()
|
||||
|
||||
assert memory.chat_memory.messages == []
|
@ -0,0 +1,7 @@
|
||||
import pytest
|
||||
|
||||
|
||||
@pytest.mark.compile
|
||||
def test_placeholder() -> None:
|
||||
"""Used for compiling integration tests without running any real tests."""
|
||||
pass
|
@ -0,0 +1,48 @@
|
||||
"""Test elasticsearch_embeddings embeddings."""
|
||||
|
||||
import pytest
|
||||
from langchain_core.utils import get_from_env
|
||||
|
||||
from langchain_elasticsearch.embeddings import ElasticsearchEmbeddings
|
||||
|
||||
# deployed with
|
||||
# https://www.elastic.co/guide/en/machine-learning/current/ml-nlp-text-emb-vector-search-example.html
|
||||
DEFAULT_MODEL = "sentence-transformers__msmarco-minilm-l-12-v3"
|
||||
DEFAULT_NUM_DIMENSIONS = "384"
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def model_id() -> str:
|
||||
return get_from_env("model_id", "MODEL_ID", DEFAULT_MODEL)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def expected_num_dimensions() -> int:
|
||||
return int(
|
||||
get_from_env(
|
||||
"expected_num_dimensions", "EXPECTED_NUM_DIMENSIONS", DEFAULT_NUM_DIMENSIONS
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
def test_elasticsearch_embedding_documents(
|
||||
model_id: str, expected_num_dimensions: int
|
||||
) -> None:
|
||||
"""Test Elasticsearch embedding documents."""
|
||||
documents = ["foo bar", "bar foo", "foo"]
|
||||
embedding = ElasticsearchEmbeddings.from_credentials(model_id)
|
||||
output = embedding.embed_documents(documents)
|
||||
assert len(output) == 3
|
||||
assert len(output[0]) == expected_num_dimensions
|
||||
assert len(output[1]) == expected_num_dimensions
|
||||
assert len(output[2]) == expected_num_dimensions
|
||||
|
||||
|
||||
def test_elasticsearch_embedding_query(
|
||||
model_id: str, expected_num_dimensions: int
|
||||
) -> None:
|
||||
"""Test Elasticsearch embedding query."""
|
||||
document = "foo bar"
|
||||
embedding = ElasticsearchEmbeddings.from_credentials(model_id)
|
||||
output = embedding.embed_query(document)
|
||||
assert len(output) == expected_num_dimensions
|
@ -0,0 +1,931 @@
|
||||
"""Test ElasticsearchStore functionality."""
|
||||
|
||||
import logging
|
||||
import os
|
||||
import re
|
||||
import uuid
|
||||
from typing import Any, Dict, Generator, List, Union
|
||||
|
||||
import pytest
|
||||
from elastic_transport import Transport
|
||||
from elasticsearch import Elasticsearch
|
||||
from elasticsearch.helpers import BulkIndexError
|
||||
from langchain_core.documents import Document
|
||||
|
||||
from langchain_elasticsearch.vectorstores import ElasticsearchStore
|
||||
|
||||
from ..fake_embeddings import (
|
||||
ConsistentFakeEmbeddings,
|
||||
FakeEmbeddings,
|
||||
)
|
||||
|
||||
logging.basicConfig(level=logging.DEBUG)
|
||||
|
||||
"""
|
||||
cd tests/integration_tests/vectorstores/docker-compose
|
||||
docker-compose -f elasticsearch.yml up
|
||||
|
||||
By default runs against local docker instance of Elasticsearch.
|
||||
To run against Elastic Cloud, set the following environment variables:
|
||||
- ES_CLOUD_ID
|
||||
- ES_API_KEY
|
||||
|
||||
Some of the tests require the following models to be deployed in the ML Node:
|
||||
- elser (can be downloaded and deployed through Kibana and trained models UI)
|
||||
- sentence-transformers__all-minilm-l6-v2 (can be deployed
|
||||
through API, loaded via eland)
|
||||
|
||||
These tests that require the models to be deployed are skipped by default.
|
||||
Enable them by adding the model name to the modelsDeployed list below.
|
||||
"""
|
||||
|
||||
modelsDeployed: List[str] = [
|
||||
# "elser",
|
||||
# "sentence-transformers__all-minilm-l6-v2",
|
||||
]
|
||||
|
||||
|
||||
class TestElasticsearch:
|
||||
@classmethod
|
||||
def setup_class(cls) -> None:
|
||||
if not os.getenv("OPENAI_API_KEY"):
|
||||
raise ValueError("OPENAI_API_KEY environment variable is not set")
|
||||
|
||||
@pytest.fixture(scope="class", autouse=True)
|
||||
def elasticsearch_connection(self) -> Union[dict, Generator[dict, None, None]]:
|
||||
es_url = os.environ.get("ES_URL", "http://localhost:9200")
|
||||
cloud_id = os.environ.get("ES_CLOUD_ID")
|
||||
api_key = os.environ.get("ES_API_KEY")
|
||||
|
||||
if cloud_id:
|
||||
# Running this integration test with Elastic Cloud
|
||||
# Required for in-stack inference testing (ELSER + model_id)
|
||||
es = Elasticsearch(
|
||||
cloud_id=cloud_id,
|
||||
api_key=api_key,
|
||||
)
|
||||
yield {
|
||||
"es_cloud_id": cloud_id,
|
||||
"es_api_key": api_key,
|
||||
}
|
||||
|
||||
else:
|
||||
# Running this integration test with local docker instance
|
||||
es = Elasticsearch(hosts=es_url)
|
||||
yield {"es_url": es_url}
|
||||
|
||||
# Clear all indexes
|
||||
index_names = es.indices.get(index="_all").keys()
|
||||
for index_name in index_names:
|
||||
if index_name.startswith("test_"):
|
||||
es.indices.delete(index=index_name)
|
||||
es.indices.refresh(index="_all")
|
||||
|
||||
# clear all test pipelines
|
||||
try:
|
||||
response = es.ingest.get_pipeline(id="test_*,*_sparse_embedding")
|
||||
|
||||
for pipeline_id, _ in response.items():
|
||||
try:
|
||||
es.ingest.delete_pipeline(id=pipeline_id)
|
||||
print(f"Deleted pipeline: {pipeline_id}") # noqa: T201
|
||||
except Exception as e:
|
||||
print(f"Pipeline error: {e}") # noqa: T201
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
@pytest.fixture(scope="function")
|
||||
def es_client(self) -> Any:
|
||||
class CustomTransport(Transport):
|
||||
requests = []
|
||||
|
||||
def perform_request(self, *args, **kwargs): # type: ignore
|
||||
self.requests.append(kwargs)
|
||||
return super().perform_request(*args, **kwargs)
|
||||
|
||||
es_url = os.environ.get("ES_URL", "http://localhost:9200")
|
||||
cloud_id = os.environ.get("ES_CLOUD_ID")
|
||||
api_key = os.environ.get("ES_API_KEY")
|
||||
|
||||
if cloud_id:
|
||||
# Running this integration test with Elastic Cloud
|
||||
# Required for in-stack inference testing (ELSER + model_id)
|
||||
es = Elasticsearch(
|
||||
cloud_id=cloud_id,
|
||||
api_key=api_key,
|
||||
transport_class=CustomTransport,
|
||||
)
|
||||
return es
|
||||
else:
|
||||
# Running this integration test with local docker instance
|
||||
es = Elasticsearch(hosts=es_url, transport_class=CustomTransport)
|
||||
return es
|
||||
|
||||
@pytest.fixture(scope="function")
|
||||
def index_name(self) -> str:
|
||||
"""Return the index name."""
|
||||
return f"test_{uuid.uuid4().hex}"
|
||||
|
||||
def test_similarity_search_without_metadata(
|
||||
self, elasticsearch_connection: dict, index_name: str
|
||||
) -> None:
|
||||
"""Test end to end construction and search without metadata."""
|
||||
|
||||
def assert_query(query_body: dict, query: str) -> dict:
|
||||
assert query_body == {
|
||||
"knn": {
|
||||
"field": "vector",
|
||||
"filter": [],
|
||||
"k": 1,
|
||||
"num_candidates": 50,
|
||||
"query_vector": [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0],
|
||||
}
|
||||
}
|
||||
return query_body
|
||||
|
||||
texts = ["foo", "bar", "baz"]
|
||||
docsearch = ElasticsearchStore.from_texts(
|
||||
texts,
|
||||
FakeEmbeddings(),
|
||||
**elasticsearch_connection,
|
||||
index_name=index_name,
|
||||
)
|
||||
output = docsearch.similarity_search("foo", k=1, custom_query=assert_query)
|
||||
assert output == [Document(page_content="foo")]
|
||||
|
||||
async def test_similarity_search_without_metadata_async(
|
||||
self, elasticsearch_connection: dict, index_name: str
|
||||
) -> None:
|
||||
"""Test end to end construction and search without metadata."""
|
||||
texts = ["foo", "bar", "baz"]
|
||||
docsearch = ElasticsearchStore.from_texts(
|
||||
texts,
|
||||
FakeEmbeddings(),
|
||||
**elasticsearch_connection,
|
||||
index_name=index_name,
|
||||
)
|
||||
output = await docsearch.asimilarity_search("foo", k=1)
|
||||
assert output == [Document(page_content="foo")]
|
||||
|
||||
def test_add_embeddings(
|
||||
self, elasticsearch_connection: dict, index_name: str
|
||||
) -> None:
|
||||
"""
|
||||
Test add_embeddings, which accepts pre-built embeddings instead of
|
||||
using inference for the texts.
|
||||
This allows you to separate the embeddings text and the page_content
|
||||
for better proximity between user's question and embedded text.
|
||||
For example, your embedding text can be a question, whereas page_content
|
||||
is the answer.
|
||||
"""
|
||||
embeddings = ConsistentFakeEmbeddings()
|
||||
text_input = ["foo1", "foo2", "foo3"]
|
||||
metadatas = [{"page": i} for i in range(len(text_input))]
|
||||
|
||||
"""In real use case, embedding_input can be questions for each text"""
|
||||
embedding_input = ["foo2", "foo3", "foo1"]
|
||||
embedding_vectors = embeddings.embed_documents(embedding_input)
|
||||
|
||||
docsearch = ElasticsearchStore._create_cls_from_kwargs(
|
||||
embeddings,
|
||||
**elasticsearch_connection,
|
||||
index_name=index_name,
|
||||
)
|
||||
docsearch.add_embeddings(list(zip(text_input, embedding_vectors)), metadatas)
|
||||
output = docsearch.similarity_search("foo1", k=1)
|
||||
assert output == [Document(page_content="foo3", metadata={"page": 2})]
|
||||
|
||||
def test_similarity_search_with_metadata(
|
||||
self, elasticsearch_connection: dict, index_name: str
|
||||
) -> None:
|
||||
"""Test end to end construction and search with metadata."""
|
||||
texts = ["foo", "bar", "baz"]
|
||||
metadatas = [{"page": i} for i in range(len(texts))]
|
||||
docsearch = ElasticsearchStore.from_texts(
|
||||
texts,
|
||||
ConsistentFakeEmbeddings(),
|
||||
metadatas=metadatas,
|
||||
**elasticsearch_connection,
|
||||
index_name=index_name,
|
||||
)
|
||||
|
||||
output = docsearch.similarity_search("foo", k=1)
|
||||
assert output == [Document(page_content="foo", metadata={"page": 0})]
|
||||
|
||||
output = docsearch.similarity_search("bar", k=1)
|
||||
assert output == [Document(page_content="bar", metadata={"page": 1})]
|
||||
|
||||
def test_similarity_search_with_filter(
|
||||
self, elasticsearch_connection: dict, index_name: str
|
||||
) -> None:
|
||||
"""Test end to end construction and search with metadata."""
|
||||
texts = ["foo", "foo", "foo"]
|
||||
metadatas = [{"page": i} for i in range(len(texts))]
|
||||
docsearch = ElasticsearchStore.from_texts(
|
||||
texts,
|
||||
FakeEmbeddings(),
|
||||
metadatas=metadatas,
|
||||
**elasticsearch_connection,
|
||||
index_name=index_name,
|
||||
)
|
||||
|
||||
def assert_query(query_body: dict, query: str) -> dict:
|
||||
assert query_body == {
|
||||
"knn": {
|
||||
"field": "vector",
|
||||
"filter": [{"term": {"metadata.page": "1"}}],
|
||||
"k": 3,
|
||||
"num_candidates": 50,
|
||||
"query_vector": [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0],
|
||||
}
|
||||
}
|
||||
return query_body
|
||||
|
||||
output = docsearch.similarity_search(
|
||||
query="foo",
|
||||
k=3,
|
||||
filter=[{"term": {"metadata.page": "1"}}],
|
||||
custom_query=assert_query,
|
||||
)
|
||||
assert output == [Document(page_content="foo", metadata={"page": 1})]
|
||||
|
||||
def test_similarity_search_with_doc_builder(
|
||||
self, elasticsearch_connection: dict, index_name: str
|
||||
) -> None:
|
||||
texts = ["foo", "foo", "foo"]
|
||||
metadatas = [{"page": i} for i in range(len(texts))]
|
||||
docsearch = ElasticsearchStore.from_texts(
|
||||
texts,
|
||||
FakeEmbeddings(),
|
||||
metadatas=metadatas,
|
||||
**elasticsearch_connection,
|
||||
index_name=index_name,
|
||||
)
|
||||
|
||||
def custom_document_builder(_: Dict) -> Document:
|
||||
return Document(
|
||||
page_content="Mock content!",
|
||||
metadata={
|
||||
"page_number": -1,
|
||||
"original_filename": "Mock filename!",
|
||||
},
|
||||
)
|
||||
|
||||
output = docsearch.similarity_search(
|
||||
query="foo", k=1, doc_builder=custom_document_builder
|
||||
)
|
||||
assert output[0].page_content == "Mock content!"
|
||||
assert output[0].metadata["page_number"] == -1
|
||||
assert output[0].metadata["original_filename"] == "Mock filename!"
|
||||
|
||||
def test_similarity_search_exact_search(
|
||||
self, elasticsearch_connection: dict, index_name: str
|
||||
) -> None:
|
||||
"""Test end to end construction and search with metadata."""
|
||||
texts = ["foo", "bar", "baz"]
|
||||
docsearch = ElasticsearchStore.from_texts(
|
||||
texts,
|
||||
FakeEmbeddings(),
|
||||
**elasticsearch_connection,
|
||||
index_name=index_name,
|
||||
strategy=ElasticsearchStore.ExactRetrievalStrategy(),
|
||||
)
|
||||
|
||||
expected_query = {
|
||||
"query": {
|
||||
"script_score": {
|
||||
"query": {"match_all": {}},
|
||||
"script": {
|
||||
"source": "cosineSimilarity(params.query_vector, 'vector') + 1.0", # noqa: E501
|
||||
"params": {
|
||||
"query_vector": [
|
||||
1.0,
|
||||
1.0,
|
||||
1.0,
|
||||
1.0,
|
||||
1.0,
|
||||
1.0,
|
||||
1.0,
|
||||
1.0,
|
||||
1.0,
|
||||
0.0,
|
||||
]
|
||||
},
|
||||
},
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
def assert_query(query_body: dict, query: str) -> dict:
|
||||
assert query_body == expected_query
|
||||
return query_body
|
||||
|
||||
output = docsearch.similarity_search("foo", k=1, custom_query=assert_query)
|
||||
assert output == [Document(page_content="foo")]
|
||||
|
||||
def test_similarity_search_exact_search_with_filter(
|
||||
self, elasticsearch_connection: dict, index_name: str
|
||||
) -> None:
|
||||
"""Test end to end construction and search with metadata."""
|
||||
texts = ["foo", "bar", "baz"]
|
||||
metadatas = [{"page": i} for i in range(len(texts))]
|
||||
docsearch = ElasticsearchStore.from_texts(
|
||||
texts,
|
||||
FakeEmbeddings(),
|
||||
**elasticsearch_connection,
|
||||
index_name=index_name,
|
||||
metadatas=metadatas,
|
||||
strategy=ElasticsearchStore.ExactRetrievalStrategy(),
|
||||
)
|
||||
|
||||
def assert_query(query_body: dict, query: str) -> dict:
|
||||
expected_query = {
|
||||
"query": {
|
||||
"script_score": {
|
||||
"query": {"bool": {"filter": [{"term": {"metadata.page": 0}}]}},
|
||||
"script": {
|
||||
"source": "cosineSimilarity(params.query_vector, 'vector') + 1.0", # noqa: E501
|
||||
"params": {
|
||||
"query_vector": [
|
||||
1.0,
|
||||
1.0,
|
||||
1.0,
|
||||
1.0,
|
||||
1.0,
|
||||
1.0,
|
||||
1.0,
|
||||
1.0,
|
||||
1.0,
|
||||
0.0,
|
||||
]
|
||||
},
|
||||
},
|
||||
}
|
||||
}
|
||||
}
|
||||
assert query_body == expected_query
|
||||
return query_body
|
||||
|
||||
output = docsearch.similarity_search(
|
||||
"foo",
|
||||
k=1,
|
||||
custom_query=assert_query,
|
||||
filter=[{"term": {"metadata.page": 0}}],
|
||||
)
|
||||
assert output == [Document(page_content="foo", metadata={"page": 0})]
|
||||
|
||||
    def test_similarity_search_exact_search_distance_dot_product(
        self, elasticsearch_connection: dict, index_name: str
    ) -> None:
        """Test end to end construction and search with dot product distance."""
        texts = ["foo", "bar", "baz"]
        docsearch = ElasticsearchStore.from_texts(
            texts,
            FakeEmbeddings(),
            **elasticsearch_connection,
            index_name=index_name,
            strategy=ElasticsearchStore.ExactRetrievalStrategy(),
            distance_strategy="DOT_PRODUCT",
        )

        def assert_query(query_body: dict, query: str) -> dict:
            assert query_body == {
                "query": {
                    "script_score": {
                        "query": {"match_all": {}},
                        "script": {
                            "source": """
            double value = dotProduct(params.query_vector, 'vector');
            return sigmoid(1, Math.E, -value);
            """,
                            "params": {
                                "query_vector": [
                                    1.0,
                                    1.0,
                                    1.0,
                                    1.0,
                                    1.0,
                                    1.0,
                                    1.0,
                                    1.0,
                                    1.0,
                                    0.0,
                                ]
                            },
                        },
                    }
                }
            }
            return query_body

        output = docsearch.similarity_search("foo", k=1, custom_query=assert_query)
        assert output == [Document(page_content="foo")]

    def test_similarity_search_exact_search_unknown_distance_strategy(
        self, elasticsearch_connection: dict, index_name: str
    ) -> None:
        """Test end to end construction and search with unknown distance strategy."""

        with pytest.raises(KeyError):
            texts = ["foo", "bar", "baz"]
            ElasticsearchStore.from_texts(
                texts,
                FakeEmbeddings(),
                **elasticsearch_connection,
                index_name=index_name,
                strategy=ElasticsearchStore.ExactRetrievalStrategy(),
                distance_strategy="NOT_A_STRATEGY",
            )

    def test_max_marginal_relevance_search(
        self, elasticsearch_connection: dict, index_name: str
    ) -> None:
        """Test max marginal relevance search."""
        texts = ["foo", "bar", "baz"]
        docsearch = ElasticsearchStore.from_texts(
            texts,
            FakeEmbeddings(),
            **elasticsearch_connection,
            index_name=index_name,
            strategy=ElasticsearchStore.ExactRetrievalStrategy(),
        )

        mmr_output = docsearch.max_marginal_relevance_search(texts[0], k=3, fetch_k=3)
        sim_output = docsearch.similarity_search(texts[0], k=3)
        assert mmr_output == sim_output

        mmr_output = docsearch.max_marginal_relevance_search(texts[0], k=2, fetch_k=3)
        assert len(mmr_output) == 2
        assert mmr_output[0].page_content == texts[0]
        assert mmr_output[1].page_content == texts[1]

        mmr_output = docsearch.max_marginal_relevance_search(
            texts[0],
            k=2,
            fetch_k=3,
            lambda_mult=0.1,  # more diversity
        )
        assert len(mmr_output) == 2
        assert mmr_output[0].page_content == texts[0]
        assert mmr_output[1].page_content == texts[2]

        # if fetch_k < k, then the output will be less than k
        mmr_output = docsearch.max_marginal_relevance_search(texts[0], k=3, fetch_k=2)
        assert len(mmr_output) == 2

    def test_similarity_search_approx_with_hybrid_search(
        self, elasticsearch_connection: dict, index_name: str
    ) -> None:
        """Test end to end construction and hybrid (kNN + BM25) search."""
        texts = ["foo", "bar", "baz"]
        docsearch = ElasticsearchStore.from_texts(
            texts,
            FakeEmbeddings(),
            **elasticsearch_connection,
            index_name=index_name,
            strategy=ElasticsearchStore.ApproxRetrievalStrategy(hybrid=True),
        )

        def assert_query(query_body: dict, query: str) -> dict:
            assert query_body == {
                "knn": {
                    "field": "vector",
                    "filter": [],
                    "k": 1,
                    "num_candidates": 50,
                    "query_vector": [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0],
                },
                "query": {
                    "bool": {
                        "filter": [],
                        "must": [{"match": {"text": {"query": "foo"}}}],
                    }
                },
                "rank": {"rrf": {}},
            }
            return query_body

        output = docsearch.similarity_search("foo", k=1, custom_query=assert_query)
        assert output == [Document(page_content="foo")]

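    # Reciprocal rank fusion (RRF): with hybrid=True the strategy sends both a
    # BM25 "query" clause and a "knn" clause and lets Elasticsearch merge the two
    # rankings via the "rank": {"rrf": ...} option. The rrf argument may be True
    # (default RRF settings), False (no rank section), or a dict of parameters
    # such as rank_constant and window_size; the loop below exercises all three.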
    def test_similarity_search_approx_with_hybrid_search_rrf(
        self, es_client: Any, elasticsearch_connection: dict, index_name: str
    ) -> None:
        """Test end to end construction and rrf hybrid search with metadata."""
        from functools import partial
        from typing import Optional

        # 1. check query_body is okay
        rrf_test_cases: List[Optional[Union[dict, bool]]] = [
            True,
            False,
            {"rank_constant": 1, "window_size": 5},
        ]
        for rrf_test_case in rrf_test_cases:
            texts = ["foo", "bar", "baz"]
            docsearch = ElasticsearchStore.from_texts(
                texts,
                FakeEmbeddings(),
                **elasticsearch_connection,
                index_name=index_name,
                strategy=ElasticsearchStore.ApproxRetrievalStrategy(
                    hybrid=True, rrf=rrf_test_case
                ),
            )

            def assert_query(
                query_body: dict,
                query: str,
                rrf: Optional[Union[dict, bool]] = True,
            ) -> dict:
                cmp_query_body = {
                    "knn": {
                        "field": "vector",
                        "filter": [],
                        "k": 3,
                        "num_candidates": 50,
                        "query_vector": [
                            1.0,
                            1.0,
                            1.0,
                            1.0,
                            1.0,
                            1.0,
                            1.0,
                            1.0,
                            1.0,
                            0.0,
                        ],
                    },
                    "query": {
                        "bool": {
                            "filter": [],
                            "must": [{"match": {"text": {"query": "foo"}}}],
                        }
                    },
                }

                if isinstance(rrf, dict):
                    cmp_query_body["rank"] = {"rrf": rrf}
                elif isinstance(rrf, bool) and rrf is True:
                    cmp_query_body["rank"] = {"rrf": {}}

                assert query_body == cmp_query_body

                return query_body

            ## without fetch_k parameter
            output = docsearch.similarity_search(
                "foo", k=3, custom_query=partial(assert_query, rrf=rrf_test_case)
            )

        # 2. check query result is okay
        es_output = es_client.search(
            index=index_name,
            query={
                "bool": {
                    "filter": [],
                    "must": [{"match": {"text": {"query": "foo"}}}],
                }
            },
            knn={
                "field": "vector",
                "filter": [],
                "k": 3,
                "num_candidates": 50,
                "query_vector": [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0],
            },
            size=3,
            rank={"rrf": {"rank_constant": 1, "window_size": 5}},
        )

        assert [o.page_content for o in output] == [
            e["_source"]["text"] for e in es_output["hits"]["hits"]
        ]

        # 3. check rrf default option is okay
        docsearch = ElasticsearchStore.from_texts(
            texts,
            FakeEmbeddings(),
            **elasticsearch_connection,
            index_name=index_name,
            strategy=ElasticsearchStore.ApproxRetrievalStrategy(hybrid=True),
        )

        ## with fetch_k parameter
        output = docsearch.similarity_search(
            "foo", k=3, fetch_k=50, custom_query=assert_query
        )

    def test_similarity_search_approx_with_custom_query_fn(
        self, elasticsearch_connection: dict, index_name: str
    ) -> None:
        """test that custom query function is called
        with the query string and query body"""

        def my_custom_query(query_body: dict, query: str) -> dict:
            assert query == "foo"
            assert query_body == {
                "knn": {
                    "field": "vector",
                    "filter": [],
                    "k": 1,
                    "num_candidates": 50,
                    "query_vector": [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0],
                }
            }
            return {"query": {"match": {"text": {"query": "bar"}}}}

        """Test end to end construction and search with metadata."""
        texts = ["foo", "bar", "baz"]
        docsearch = ElasticsearchStore.from_texts(
            texts, FakeEmbeddings(), **elasticsearch_connection, index_name=index_name
        )
        output = docsearch.similarity_search("foo", k=1, custom_query=my_custom_query)
        assert output == [Document(page_content="bar")]

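    # In-stack inference: rather than embedding client-side, the next test installs
    # an ingest pipeline that runs a deployed sentence-transformers model inside
    # Elasticsearch, and searches via "query_vector_builder" so the query text is
    # embedded by the cluster as well.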
    @pytest.mark.skipif(
        "sentence-transformers__all-minilm-l6-v2" not in modelsDeployed,
        reason="Sentence Transformers model not deployed in ML Node, skipping test",
    )
    def test_similarity_search_with_approx_infer_instack(
        self, elasticsearch_connection: dict, index_name: str
    ) -> None:
        """test end to end with approx retrieval strategy and inference in-stack"""
        docsearch = ElasticsearchStore(
            index_name=index_name,
            strategy=ElasticsearchStore.ApproxRetrievalStrategy(
                query_model_id="sentence-transformers__all-minilm-l6-v2"
            ),
            query_field="text_field",
            vector_query_field="vector_query_field.predicted_value",
            **elasticsearch_connection,
        )

        # setting up the pipeline for inference
        docsearch.client.ingest.put_pipeline(
            id="test_pipeline",
            processors=[
                {
                    "inference": {
                        "model_id": "sentence-transformers__all-minilm-l6-v2",
                        "field_map": {"query_field": "text_field"},
                        "target_field": "vector_query_field",
                    }
                }
            ],
        )

        # creating a new index with the pipeline,
        # not relying on langchain to create the index
        docsearch.client.indices.create(
            index=index_name,
            mappings={
                "properties": {
                    "text_field": {"type": "text"},
                    "vector_query_field": {
                        "properties": {
                            "predicted_value": {
                                "type": "dense_vector",
                                "dims": 384,
                                "index": True,
                                "similarity": "l2_norm",
                            }
                        }
                    },
                }
            },
            settings={"index": {"default_pipeline": "test_pipeline"}},
        )

        # adding documents to the index
        texts = ["foo", "bar", "baz"]

        for i, text in enumerate(texts):
            docsearch.client.create(
                index=index_name,
                id=str(i),
                document={"text_field": text, "metadata": {}},
            )

        docsearch.client.indices.refresh(index=index_name)

        def assert_query(query_body: dict, query: str) -> dict:
            assert query_body == {
                "knn": {
                    "filter": [],
                    "field": "vector_query_field.predicted_value",
                    "k": 1,
                    "num_candidates": 50,
                    "query_vector_builder": {
                        "text_embedding": {
                            "model_id": "sentence-transformers__all-minilm-l6-v2",
                            "model_text": "foo",
                        }
                    },
                }
            }
            return query_body

        output = docsearch.similarity_search("foo", k=1, custom_query=assert_query)
        assert output == [Document(page_content="foo")]

        output = docsearch.similarity_search("bar", k=1)
        assert output == [Document(page_content="bar")]

    @pytest.mark.skipif(
        "elser" not in modelsDeployed,
        reason="ELSER not deployed in ML Node, skipping test",
    )
    def test_similarity_search_with_sparse_infer_instack(
        self, elasticsearch_connection: dict, index_name: str
    ) -> None:
        """test end to end with sparse retrieval strategy and inference in-stack"""
        texts = ["foo", "bar", "baz"]
        docsearch = ElasticsearchStore.from_texts(
            texts,
            **elasticsearch_connection,
            index_name=index_name,
            strategy=ElasticsearchStore.SparseVectorRetrievalStrategy(),
        )
        output = docsearch.similarity_search("foo", k=1)
        assert output == [Document(page_content="foo")]

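    # The next two tests depend on relevance scores being normalized to the 0-1
    # range: the best hit for an exact match should score 1.0, and a
    # "similarity_score_threshold" retriever should drop hits scoring below the
    # configured threshold.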
    def test_elasticsearch_with_relevance_score(
        self, elasticsearch_connection: dict, index_name: str
    ) -> None:
        """Test to make sure the relevance score is scaled to 0-1."""
        texts = ["foo", "bar", "baz"]
        metadatas = [{"page": str(i)} for i in range(len(texts))]
        embeddings = FakeEmbeddings()

        docsearch = ElasticsearchStore.from_texts(
            index_name=index_name,
            texts=texts,
            embedding=embeddings,
            metadatas=metadatas,
            **elasticsearch_connection,
        )

        embedded_query = embeddings.embed_query("foo")
        output = docsearch.similarity_search_by_vector_with_relevance_scores(
            embedding=embedded_query, k=1
        )
        assert output == [(Document(page_content="foo", metadata={"page": "0"}), 1.0)]

    def test_elasticsearch_with_relevance_threshold(
        self, elasticsearch_connection: dict, index_name: str
    ) -> None:
        """Test to make sure the relevance threshold is respected."""
        texts = ["foo", "bar", "baz"]
        metadatas = [{"page": str(i)} for i in range(len(texts))]
        embeddings = FakeEmbeddings()

        docsearch = ElasticsearchStore.from_texts(
            index_name=index_name,
            texts=texts,
            embedding=embeddings,
            metadatas=metadatas,
            **elasticsearch_connection,
        )

        # Find a good threshold for testing
        query_string = "foo"
        embedded_query = embeddings.embed_query(query_string)
        top3 = docsearch.similarity_search_by_vector_with_relevance_scores(
            embedding=embedded_query, k=3
        )
        similarity_of_second_ranked = top3[1][1]
        assert len(top3) == 3

        # Test threshold
        retriever = docsearch.as_retriever(
            search_type="similarity_score_threshold",
            search_kwargs={"score_threshold": similarity_of_second_ranked},
        )
        output = retriever.get_relevant_documents(query=query_string)

        assert output == [
            top3[0][0],
            top3[1][0],
            # third ranked is out
        ]

    def test_elasticsearch_delete_ids(
        self, elasticsearch_connection: dict, index_name: str
    ) -> None:
        """Test delete methods from vector store."""
        texts = ["foo", "bar", "baz", "gni"]
        metadatas = [{"page": i} for i in range(len(texts))]
        docsearch = ElasticsearchStore(
            embedding=ConsistentFakeEmbeddings(),
            **elasticsearch_connection,
            index_name=index_name,
        )

        ids = docsearch.add_texts(texts, metadatas)
        output = docsearch.similarity_search("foo", k=10)
        assert len(output) == 4

        docsearch.delete(ids[1:3])
        output = docsearch.similarity_search("foo", k=10)
        assert len(output) == 2

        docsearch.delete(["not-existing"])
        output = docsearch.similarity_search("foo", k=10)
        assert len(output) == 2

        docsearch.delete([ids[0]])
        output = docsearch.similarity_search("foo", k=10)
        assert len(output) == 1

        docsearch.delete([ids[3]])
        output = docsearch.similarity_search("gni", k=10)
        assert len(output) == 0

    def test_elasticsearch_indexing_exception_error(
        self,
        elasticsearch_connection: dict,
        index_name: str,
        caplog: pytest.LogCaptureFixture,
    ) -> None:
        """Test bulk exception logging is giving better hints."""

        docsearch = ElasticsearchStore(
            embedding=ConsistentFakeEmbeddings(),
            **elasticsearch_connection,
            index_name=index_name,
        )

        docsearch.client.indices.create(
            index=index_name,
            mappings={"properties": {}},
            settings={"index": {"default_pipeline": "not-existing-pipeline"}},
        )

        texts = ["foo"]

        with pytest.raises(BulkIndexError):
            docsearch.add_texts(texts)

        error_reason = "pipeline with id [not-existing-pipeline] does not exist"
        log_message = f"First error reason: {error_reason}"

        assert log_message in caplog.text

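    # The user-agent tests check that the store identifies itself to Elasticsearch
    # with a header of the form "langchain-py-vs/<major>.<minor>.<patch>", for
    # example "langchain-py-vs/0.1.0" (the version shown is illustrative).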
    def test_elasticsearch_with_user_agent(
        self, es_client: Any, index_name: str
    ) -> None:
        """Test to make sure the user-agent is set correctly."""

        texts = ["foo", "bob", "baz"]
        ElasticsearchStore.from_texts(
            texts,
            FakeEmbeddings(),
            es_connection=es_client,
            index_name=index_name,
        )

        user_agent = es_client.transport.requests[0]["headers"]["User-Agent"]
        pattern = r"^langchain-py-vs/\d+\.\d+\.\d+$"
        match = re.match(pattern, user_agent)

        assert (
            match is not None
        ), f"The string '{user_agent}' does not match the expected pattern."

    def test_elasticsearch_with_internal_user_agent(
        self, elasticsearch_connection: Dict, index_name: str
    ) -> None:
        """Test to make sure the user-agent is set correctly."""

        texts = ["foo"]
        store = ElasticsearchStore.from_texts(
            texts,
            FakeEmbeddings(),
            **elasticsearch_connection,
            index_name=index_name,
        )

        user_agent = store.client._headers["User-Agent"]
        pattern = r"^langchain-py-vs/\d+\.\d+\.\d+$"
        match = re.match(pattern, user_agent)

        assert (
            match is not None
        ), f"The string '{user_agent}' does not match the expected pattern."

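    # With bulk_kwargs={"chunk_size": 1} every document goes out in its own bulk
    # request, so indexing three texts is expected to yield five requests total:
    # one index-exists check, one index creation, and three document bulks.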
    def test_bulk_args(self, es_client: Any, index_name: str) -> None:
        """Test that bulk_kwargs are passed on to the bulk helper."""

        texts = ["foo", "bob", "baz"]
        ElasticsearchStore.from_texts(
            texts,
            FakeEmbeddings(),
            es_connection=es_client,
            index_name=index_name,
            bulk_kwargs={"chunk_size": 1},
        )

        # 1 for index exist, 1 for index create, 3 for index docs
        assert len(es_client.transport.requests) == 5  # type: ignore
14
libs/partners/elasticsearch/tests/unit_tests/test_imports.py
Normal file
@ -0,0 +1,14 @@
from langchain_elasticsearch import __all__

EXPECTED_ALL = [
    "ApproxRetrievalStrategy",
    "ElasticsearchChatMessageHistory",
    "ElasticsearchEmbeddings",
    "ElasticsearchStore",
    "ExactRetrievalStrategy",
    "SparseRetrievalStrategy",
]


def test_all_imports() -> None:
    assert sorted(EXPECTED_ALL) == sorted(__all__)
@ -0,0 +1,34 @@
"""Test Elasticsearch functionality."""

import pytest

from langchain_elasticsearch.vectorstores import (
    ApproxRetrievalStrategy,
    ElasticsearchStore,
)

from ..fake_embeddings import FakeEmbeddings


@pytest.mark.requires("elasticsearch")
def test_elasticsearch_hybrid_scores_guard() -> None:
    """Ensure an error is raised when searching with score in hybrid mode
    because in this case Elasticsearch does not return any score.
    """
    from elasticsearch import Elasticsearch

    query_string = "foo"
    embeddings = FakeEmbeddings()

    store = ElasticsearchStore(
        index_name="dummy_index",
        es_connection=Elasticsearch(hosts=["http://dummy-host:9200"]),
        embedding=embeddings,
        strategy=ApproxRetrievalStrategy(hybrid=True),
    )
    with pytest.raises(ValueError):
        store.similarity_search_with_score(query_string)

    embedded_query = embeddings.embed_query(query_string)
    with pytest.raises(ValueError):
        store.similarity_search_by_vector_with_relevance_scores(embedded_query)