Mirror of https://github.com/hwchase17/langchain.git (synced 2025-07-20 11:31:58 +00:00)
partners: add Elasticsearch package (#17467)
### Description

This PR moves the Elasticsearch classes to a partners package.

Note that we will not move (and will later remove) `ElasticKnnSearch`; it was previously deprecated. `ElasticVectorSearch` is going to stay in the community package since it is still used quite a lot.

Also note that I left the `ElasticsearchTranslator` for self query untouched because it resides in the main `langchain` package.

### Dependencies

There will be another PR that updates the notebooks (potentially pulling them into the partners package) and templates and removes the classes from the community package, see https://github.com/langchain-ai/langchain/pull/17468

#### Open question

How do we make the transition smooth for users? Do we move the import aliases and require people to install `langchain-elasticsearch`? Or do we remove the import aliases from the `langchain` package altogether? What has worked well for other partner packages?

---------

Co-authored-by: Erick Friis <erick@langchain.dev>
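For users, the visible change is the install target and the import path, roughly as sketched below; whether the old community imports keep working through aliases is exactly the open question above:

```python
# Before: the classes lived in the community package.
# from langchain_community.vectorstores import ElasticsearchStore
# from langchain_community.embeddings.elasticsearch import ElasticsearchEmbeddings

# After: pip install langchain-elasticsearch, then import from the partner package.
from langchain_elasticsearch import ElasticsearchEmbeddings, ElasticsearchStore
```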
This commit is contained in:
parent a4896da2a0
commit 5ab69f907f
3 .github/workflows/_integration_test.yml (vendored)
@@ -70,6 +70,9 @@ jobs:
           ASTRA_DB_API_ENDPOINT: ${{ secrets.ASTRA_DB_API_ENDPOINT }}
           ASTRA_DB_APPLICATION_TOKEN: ${{ secrets.ASTRA_DB_APPLICATION_TOKEN }}
           ASTRA_DB_KEYSPACE: ${{ secrets.ASTRA_DB_KEYSPACE }}
+          ES_URL: ${{ secrets.ES_URL }}
+          ES_CLOUD_ID: ${{ secrets.ES_CLOUD_ID }}
+          ES_API_KEY: ${{ secrets.ES_API_KEY }}
         run: |
           make integration_tests
 
3 .github/workflows/_release.yml (vendored)
@@ -191,6 +191,9 @@ jobs:
           ASTRA_DB_API_ENDPOINT: ${{ secrets.ASTRA_DB_API_ENDPOINT }}
           ASTRA_DB_APPLICATION_TOKEN: ${{ secrets.ASTRA_DB_APPLICATION_TOKEN }}
           ASTRA_DB_KEYSPACE: ${{ secrets.ASTRA_DB_KEYSPACE }}
+          ES_URL: ${{ secrets.ES_URL }}
+          ES_CLOUD_ID: ${{ secrets.ES_CLOUD_ID }}
+          ES_API_KEY: ${{ secrets.ES_API_KEY }}
         run: make integration_tests
         working-directory: ${{ inputs.working-directory }}
 
@@ -1083,7 +1083,7 @@
    "metadata": {},
    "outputs": [],
    "source": [
-    "from langchain_community.vectorstores import ElasticsearchStore\n",
+    "from langchain_elasticsearch import ElasticsearchStore\n",
     "from langchain_openai import OpenAIEmbeddings\n",
     "\n",
     "embeddings = OpenAIEmbeddings()"
@@ -23,7 +23,7 @@ Elastic Cloud is a managed Elasticsearch service. Signup for a [free trial](http
 ### Install Client
 
 ```bash
-pip install elasticsearch
+pip install langchain-elasticsearch
 ```
 
 ## Vector Store
@@ -31,7 +31,7 @@ pip install elasticsearch
 The vector store is a simple wrapper around Elasticsearch. It provides a simple interface to store and retrieve vectors.
 
 ```python
-from langchain_community.vectorstores import ElasticsearchStore
+from langchain_elasticsearch import ElasticsearchStore
 
 from langchain_community.document_loaders import TextLoader
 from langchain.text_splitter import CharacterTextSplitter
@@ -60,8 +60,8 @@
     "import getpass\n",
     "import os\n",
     "\n",
-    "from langchain_community.vectorstores import ElasticsearchStore\n",
     "from langchain_core.documents import Document\n",
+    "from langchain_elasticsearch import ElasticsearchStore\n",
     "from langchain_openai import OpenAIEmbeddings\n",
     "\n",
     "os.environ[\"OPENAI_API_KEY\"] = getpass.getpass(\"OpenAI API Key:\")\n",
@@ -24,7 +24,7 @@
    },
    "outputs": [],
    "source": [
-    "!pip -q install elasticsearch langchain"
+    "!pip -q install langchain-elasticsearch"
    ]
   },
   {
@@ -36,7 +36,7 @@
    },
    "outputs": [],
    "source": [
-    "from langchain_community.embeddings.elasticsearch import ElasticsearchEmbeddings"
+    "from langchain_elasticsearch import ElasticsearchEmbeddings"
    ]
   },
   {
@@ -21,7 +21,7 @@
    "metadata": {},
    "outputs": [],
    "source": [
-    "%pip install --upgrade --quiet elasticsearch langchain-openai tiktoken langchain"
+    "%pip install --upgrade --quiet langchain-elasticsearch langchain-openai tiktoken langchain"
    ]
   },
   {
@@ -64,7 +64,7 @@
     "\n",
     "Example:\n",
     "```python\n",
-    " from langchain_community.vectorstores.elasticsearch import ElasticsearchStore\n",
+    " from langchain_elasticsearch import ElasticsearchStore\n",
     " from langchain_openai import OpenAIEmbeddings\n",
     "\n",
     " embedding = OpenAIEmbeddings()\n",
@@ -79,7 +79,7 @@
     "\n",
     "Example:\n",
     "```python\n",
-    " from langchain_community.vectorstores import ElasticsearchStore\n",
+    " from langchain_elasticsearch import ElasticsearchStore\n",
     " from langchain_openai import OpenAIEmbeddings\n",
     "\n",
     " embedding = OpenAIEmbeddings()\n",
@@ -97,7 +97,7 @@
     "Example:\n",
     "```python\n",
     " import elasticsearch\n",
-    " from langchain_community.vectorstores import ElasticsearchStore\n",
+    " from langchain_elasticsearch import ElasticsearchStore\n",
     "\n",
     " es_client= elasticsearch.Elasticsearch(\n",
     " hosts=[\"http://localhost:9200\"],\n",
@@ -137,7 +137,7 @@
     "\n",
     "Example:\n",
     "```python\n",
-    " from langchain_community.vectorstores.elasticsearch import ElasticsearchStore\n",
+    " from langchain_elasticsearch import ElasticsearchStore\n",
     " from langchain_openai import OpenAIEmbeddings\n",
     "\n",
     " embedding = OpenAIEmbeddings()\n",
@@ -202,7 +202,7 @@
    },
    "outputs": [],
    "source": [
-    "from langchain_community.vectorstores import ElasticsearchStore\n",
+    "from langchain_elasticsearch import ElasticsearchStore\n",
     "from langchain_openai import OpenAIEmbeddings"
    ]
   },
@@ -817,7 +817,7 @@
     "source": [
     "from typing import Dict\n",
     "\n",
-    "from langchain.docstore.document import Document\n",
+    "from langchain_core.documents import Document\n",
     "\n",
     "\n",
     "def custom_document_builder(hit: Dict) -> Document:\n",
@@ -902,7 +902,7 @@
     "\n",
     "```python\n",
     "\n",
-    "from langchain_community.vectorstores.elasticsearch import ElasticsearchStore\n",
+    "from langchain_elasticsearch import ElasticsearchStore\n",
     "\n",
     "db = ElasticsearchStore(\n",
     " es_url=\"http://localhost:9200\",\n",
@@ -936,7 +936,7 @@
     "\n",
     "```python\n",
     "\n",
-    "from langchain_community.vectorstores.elasticsearch import ElasticsearchStore\n",
+    "from langchain_elasticsearch import ElasticsearchStore\n",
     "\n",
     "db = ElasticsearchStore(\n",
     " es_url=\"http://localhost:9200\",\n",
@@ -91,8 +91,8 @@
     "outputs": [],
     "source": [
     "from langchain.indexes import SQLRecordManager, index\n",
-    "from langchain_community.vectorstores import ElasticsearchStore\n",
     "from langchain_core.documents import Document\n",
+    "from langchain_elasticsearch import ElasticsearchStore\n",
     "from langchain_openai import OpenAIEmbeddings"
    ]
   },
1 libs/partners/elasticsearch/.gitignore (vendored, new file)
@@ -0,0 +1 @@
__pycache__
21 libs/partners/elasticsearch/LICENSE (new file)
@@ -0,0 +1,21 @@
MIT License

Copyright (c) 2024 LangChain, Inc.

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
60 libs/partners/elasticsearch/Makefile (new file)
@@ -0,0 +1,60 @@
.PHONY: all format lint test tests integration_tests docker_tests help extended_tests

# Default target executed when no arguments are given to make.
all: help

install:
	poetry install

# Define a variable for the test file path.
TEST_FILE ?= tests/unit_tests/
integration_test integration_tests: TEST_FILE=tests/integration_tests/

test tests integration_test integration_tests:
	poetry run pytest $(TEST_FILE)


######################
# LINTING AND FORMATTING
######################

# Define a variable for Python and notebook files.
PYTHON_FILES=.
MYPY_CACHE=.mypy_cache
lint format: PYTHON_FILES=.
lint_diff format_diff: PYTHON_FILES=$(shell git diff --relative=libs/partners/elasticsearch --name-only --diff-filter=d master | grep -E '\.py$$|\.ipynb$$')
lint_package: PYTHON_FILES=langchain_elasticsearch
lint_tests: PYTHON_FILES=tests
lint_tests: MYPY_CACHE=.mypy_cache_test

lint lint_diff lint_package lint_tests:
	poetry run ruff .
	poetry run ruff format $(PYTHON_FILES) --diff
	poetry run ruff --select I $(PYTHON_FILES)
	mkdir $(MYPY_CACHE); poetry run mypy $(PYTHON_FILES) --cache-dir $(MYPY_CACHE)

format format_diff:
	poetry run ruff format $(PYTHON_FILES)
	poetry run ruff --select I --fix $(PYTHON_FILES)

spell_check:
	poetry run codespell --toml pyproject.toml

spell_fix:
	poetry run codespell --toml pyproject.toml -w

check_imports: $(shell find langchain_elasticsearch -name '*.py')
	poetry run python ./scripts/check_imports.py $^

######################
# HELP
######################

help:
	@echo '----'
	@echo 'check_imports - check imports'
	@echo 'format - run code formatters'
	@echo 'lint - run linters'
	@echo 'test - run unit tests'
	@echo 'tests - run unit tests'
	@echo 'test TEST_FILE=<test_file> - run all tests in file'
29 libs/partners/elasticsearch/README.md (new file)
@@ -0,0 +1,29 @@
# langchain-elasticsearch

This package contains the LangChain integration with Elasticsearch.

## Installation

```bash
pip install -U langchain-elasticsearch
```

TODO document how to get id and key

## Usage

The `ElasticsearchStore` class exposes the connection to the Elasticsearch vector store.

```python
from langchain_elasticsearch import ElasticsearchStore

embeddings = ...  # use a LangChain Embeddings class

vectorstore = ElasticsearchStore(
    es_cloud_id="your-cloud-id",
    es_api_key="your-api-key",
    index_name="your-index-name",
    embeddings=embeddings,
)
```
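Beyond the constructor shown in the README, the same class can build an index directly from texts. A minimal sketch, assuming a local Elasticsearch at http://localhost:9200, an OpenAI embeddings object (any `Embeddings` implementation works), and a hypothetical index name:

```python
from langchain_elasticsearch import ElasticsearchStore
from langchain_openai import OpenAIEmbeddings

# Index a few texts and run a similarity search through the new partner package.
vectorstore = ElasticsearchStore.from_texts(
    ["foo", "bar", "baz"],
    OpenAIEmbeddings(),
    es_url="http://localhost:9200",  # or es_cloud_id=... / es_api_key=... for Elastic Cloud
    index_name="langchain-demo",     # hypothetical index name
)
docs = vectorstore.similarity_search("foo", k=1)
print(docs[0].page_content)
```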
@@ -0,0 +1,17 @@
from langchain_elasticsearch.chat_history import ElasticsearchChatMessageHistory
from langchain_elasticsearch.embeddings import ElasticsearchEmbeddings
from langchain_elasticsearch.vectorstores import (
    ApproxRetrievalStrategy,
    ElasticsearchStore,
    ExactRetrievalStrategy,
    SparseRetrievalStrategy,
)

__all__ = [
    "ApproxRetrievalStrategy",
    "ElasticsearchChatMessageHistory",
    "ElasticsearchEmbeddings",
    "ElasticsearchStore",
    "ExactRetrievalStrategy",
    "SparseRetrievalStrategy",
]
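The package root re-exports all public classes, including the retrieval strategies. A short sketch of wiring one in (the `strategy` keyword mirrors the community `ElasticsearchStore` API and should be treated as an assumption here):

```python
from langchain_elasticsearch import ApproxRetrievalStrategy, ElasticsearchStore
from langchain_openai import OpenAIEmbeddings

# Approximate kNN retrieval; ExactRetrievalStrategy and SparseRetrievalStrategy
# are the other exported options.
store = ElasticsearchStore(
    embedding=OpenAIEmbeddings(),
    index_name="langchain-demo",         # hypothetical index name
    es_url="http://localhost:9200",      # assumed local Elasticsearch
    strategy=ApproxRetrievalStrategy(),  # assumed keyword, as in the community version
)
```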
@@ -0,0 +1,82 @@
from enum import Enum
from typing import List, Union

import numpy as np

Matrix = Union[List[List[float]], List[np.ndarray], np.ndarray]


class DistanceStrategy(str, Enum):
    """Enumerator of the Distance strategies for calculating distances
    between vectors."""

    EUCLIDEAN_DISTANCE = "EUCLIDEAN_DISTANCE"
    MAX_INNER_PRODUCT = "MAX_INNER_PRODUCT"
    DOT_PRODUCT = "DOT_PRODUCT"
    JACCARD = "JACCARD"
    COSINE = "COSINE"


def maximal_marginal_relevance(
    query_embedding: np.ndarray,
    embedding_list: list,
    lambda_mult: float = 0.5,
    k: int = 4,
) -> List[int]:
    """Calculate maximal marginal relevance."""
    if min(k, len(embedding_list)) <= 0:
        return []
    if query_embedding.ndim == 1:
        query_embedding = np.expand_dims(query_embedding, axis=0)
    similarity_to_query = cosine_similarity(query_embedding, embedding_list)[0]
    most_similar = int(np.argmax(similarity_to_query))
    idxs = [most_similar]
    selected = np.array([embedding_list[most_similar]])
    while len(idxs) < min(k, len(embedding_list)):
        best_score = -np.inf
        idx_to_add = -1
        similarity_to_selected = cosine_similarity(embedding_list, selected)
        for i, query_score in enumerate(similarity_to_query):
            if i in idxs:
                continue
            redundant_score = max(similarity_to_selected[i])
            equation_score = (
                lambda_mult * query_score - (1 - lambda_mult) * redundant_score
            )
            if equation_score > best_score:
                best_score = equation_score
                idx_to_add = i
        idxs.append(idx_to_add)
        selected = np.append(selected, [embedding_list[idx_to_add]], axis=0)
    return idxs


def cosine_similarity(X: Matrix, Y: Matrix) -> np.ndarray:
    """Row-wise cosine similarity between two equal-width matrices."""
    if len(X) == 0 or len(Y) == 0:
        return np.array([])

    X = np.array(X)
    Y = np.array(Y)
    if X.shape[1] != Y.shape[1]:
        raise ValueError(
            f"Number of columns in X and Y must be the same. X has shape {X.shape} "
            f"and Y has shape {Y.shape}."
        )
    try:
        import simsimd as simd  # type: ignore

        X = np.array(X, dtype=np.float32)
        Y = np.array(Y, dtype=np.float32)
        Z = 1 - simd.cdist(X, Y, metric="cosine")
        if isinstance(Z, float):
            return np.array([Z])
        return Z
    except ImportError:
        X_norm = np.linalg.norm(X, axis=1)
        Y_norm = np.linalg.norm(Y, axis=1)
        # Ignore divide by zero errors run time warnings as those are handled below.
        with np.errstate(divide="ignore", invalid="ignore"):
            similarity = np.dot(X, Y.T) / np.outer(X_norm, Y_norm)
        similarity[np.isnan(similarity) | np.isinf(similarity)] = 0.0
        return similarity
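A small illustration of how the helpers above combine, using hand-made 2-D vectors (values invented for illustration); `maximal_marginal_relevance` should skip the exact duplicate in favour of the more diverse candidate:

```python
import numpy as np

# Assumes maximal_marginal_relevance from the module above is in scope.
# Toy unit vectors: candidates 0 and 1 are identical, candidate 2 is diverse.
query = np.array([1.0, 0.0])
candidates = [
    np.array([0.8, 0.6]),
    np.array([0.8, 0.6]),
    np.array([0.6, -0.8]),
]

# lambda_mult balances relevance (1.0) against diversity (0.0).
picked = maximal_marginal_relevance(query, candidates, lambda_mult=0.5, k=2)
print(picked)  # expected [0, 2]: the duplicate at index 1 is penalised for redundancy
```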
@@ -0,0 +1,201 @@
import json
import logging
from time import time
from typing import TYPE_CHECKING, Any, Dict, List, Optional

from langchain_core.chat_history import BaseChatMessageHistory
from langchain_core.messages import (
    BaseMessage,
    message_to_dict,
    messages_from_dict,
)

if TYPE_CHECKING:
    from elasticsearch import Elasticsearch

logger = logging.getLogger(__name__)


class ElasticsearchChatMessageHistory(BaseChatMessageHistory):
    """Chat message history that stores history in Elasticsearch.

    Args:
        es_url: URL of the Elasticsearch instance to connect to.
        es_cloud_id: Cloud ID of the Elasticsearch instance to connect to.
        es_user: Username to use when connecting to Elasticsearch.
        es_password: Password to use when connecting to Elasticsearch.
        es_api_key: API key to use when connecting to Elasticsearch.
        es_connection: Optional pre-existing Elasticsearch connection.
        esnsure_ascii: Used to escape ASCII symbols in json.dumps. Defaults to True.
        index: Name of the index to use.
        session_id: Arbitrary key that is used to store the messages
            of a single chat session.
    """

    def __init__(
        self,
        index: str,
        session_id: str,
        *,
        es_connection: Optional["Elasticsearch"] = None,
        es_url: Optional[str] = None,
        es_cloud_id: Optional[str] = None,
        es_user: Optional[str] = None,
        es_api_key: Optional[str] = None,
        es_password: Optional[str] = None,
        esnsure_ascii: Optional[bool] = True,
    ):
        self.index: str = index
        self.session_id: str = session_id
        self.ensure_ascii = esnsure_ascii

        # Initialize Elasticsearch client from passed client arg or connection info
        if es_connection is not None:
            self.client = es_connection.options(
                headers={"user-agent": self.get_user_agent()}
            )
        elif es_url is not None or es_cloud_id is not None:
            self.client = ElasticsearchChatMessageHistory.connect_to_elasticsearch(
                es_url=es_url,
                username=es_user,
                password=es_password,
                cloud_id=es_cloud_id,
                api_key=es_api_key,
            )
        else:
            raise ValueError(
                """Either provide a pre-existing Elasticsearch connection, \
or valid credentials for creating a new connection."""
            )

        if self.client.indices.exists(index=index):
            logger.debug(
                f"Chat history index {index} already exists, skipping creation."
            )
        else:
            logger.debug(f"Creating index {index} for storing chat history.")

            self.client.indices.create(
                index=index,
                mappings={
                    "properties": {
                        "session_id": {"type": "keyword"},
                        "created_at": {"type": "date"},
                        "history": {"type": "text"},
                    }
                },
            )

    @staticmethod
    def get_user_agent() -> str:
        from langchain_core import __version__

        return f"langchain-py-ms/{__version__}"

    @staticmethod
    def connect_to_elasticsearch(
        *,
        es_url: Optional[str] = None,
        cloud_id: Optional[str] = None,
        api_key: Optional[str] = None,
        username: Optional[str] = None,
        password: Optional[str] = None,
    ) -> "Elasticsearch":
        try:
            import elasticsearch
        except ImportError:
            raise ImportError(
                "Could not import elasticsearch python package. "
                "Please install it with `pip install elasticsearch`."
            )

        if es_url and cloud_id:
            raise ValueError(
                "Both es_url and cloud_id are defined. Please provide only one."
            )

        connection_params: Dict[str, Any] = {}

        if es_url:
            connection_params["hosts"] = [es_url]
        elif cloud_id:
            connection_params["cloud_id"] = cloud_id
        else:
            raise ValueError("Please provide either elasticsearch_url or cloud_id.")

        if api_key:
            connection_params["api_key"] = api_key
        elif username and password:
            connection_params["basic_auth"] = (username, password)

        es_client = elasticsearch.Elasticsearch(
            **connection_params,
            headers={"user-agent": ElasticsearchChatMessageHistory.get_user_agent()},
        )
        try:
            es_client.info()
        except Exception as err:
            logger.error(f"Error connecting to Elasticsearch: {err}")
            raise err

        return es_client

    @property
    def messages(self) -> List[BaseMessage]:  # type: ignore[override]
        """Retrieve the messages from Elasticsearch"""
        try:
            from elasticsearch import ApiError

            result = self.client.search(
                index=self.index,
                query={"term": {"session_id": self.session_id}},
                sort="created_at:asc",
            )
        except ApiError as err:
            logger.error(f"Could not retrieve messages from Elasticsearch: {err}")
            raise err

        if result and len(result["hits"]["hits"]) > 0:
            items = [
                json.loads(document["_source"]["history"])
                for document in result["hits"]["hits"]
            ]
        else:
            items = []

        return messages_from_dict(items)

    def add_message(self, message: BaseMessage) -> None:
        """Add a message to the chat session in Elasticsearch"""
        try:
            from elasticsearch import ApiError

            self.client.index(
                index=self.index,
                document={
                    "session_id": self.session_id,
                    "created_at": round(time() * 1000),
                    "history": json.dumps(
                        message_to_dict(message),
                        ensure_ascii=bool(self.ensure_ascii),
                    ),
                },
                refresh=True,
            )
        except ApiError as err:
            logger.error(f"Could not add message to Elasticsearch: {err}")
            raise err

    def clear(self) -> None:
        """Clear session memory in Elasticsearch"""
        try:
            from elasticsearch import ApiError

            self.client.delete_by_query(
                index=self.index,
                query={"term": {"session_id": self.session_id}},
                refresh=True,
            )
        except ApiError as err:
            logger.error(f"Could not clear session memory in Elasticsearch: {err}")
            raise err
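A brief usage sketch for `ElasticsearchChatMessageHistory`, assuming a local Elasticsearch instance; the index and session names are hypothetical:

```python
from langchain_elasticsearch.chat_history import ElasticsearchChatMessageHistory

history = ElasticsearchChatMessageHistory(
    index="chat-history",            # hypothetical index name
    session_id="session-1",          # arbitrary key identifying one conversation
    es_url="http://localhost:9200",  # assumed local instance; es_cloud_id/es_api_key also work
)

history.add_user_message("hi!")         # convenience helpers inherited from BaseChatMessageHistory
history.add_ai_message("hello, human")
print(history.messages)                 # messages are read back from the index, oldest first
history.clear()                         # removes this session's documents
```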
@@ -0,0 +1,208 @@
from __future__ import annotations

from typing import TYPE_CHECKING, List, Optional

from elasticsearch import Elasticsearch
from langchain_core.embeddings import Embeddings
from langchain_core.utils import get_from_env

if TYPE_CHECKING:
    from elasticsearch.client import MlClient


class ElasticsearchEmbeddings(Embeddings):
    """Elasticsearch embedding models.

    This class provides an interface to generate embeddings using a model deployed
    in an Elasticsearch cluster. It requires an Elasticsearch connection object
    and the model_id of the model deployed in the cluster.

    In Elasticsearch you need to have an embedding model loaded and deployed.
    - https://www.elastic.co/guide/en/elasticsearch/reference/current/infer-trained-model.html
    - https://www.elastic.co/guide/en/machine-learning/current/ml-nlp-deploy-models.html
    """  # noqa: E501

    def __init__(
        self,
        client: MlClient,
        model_id: str,
        *,
        input_field: str = "text_field",
    ):
        """
        Initialize the ElasticsearchEmbeddings instance.

        Args:
            client (MlClient): An Elasticsearch ML client object.
            model_id (str): The model_id of the model deployed in the Elasticsearch
                cluster.
            input_field (str): The name of the key for the input text field in the
                document. Defaults to 'text_field'.
        """
        self.client = client
        self.model_id = model_id
        self.input_field = input_field

    @classmethod
    def from_credentials(
        cls,
        model_id: str,
        *,
        es_cloud_id: Optional[str] = None,
        es_api_key: Optional[str] = None,
        input_field: str = "text_field",
    ) -> ElasticsearchEmbeddings:
        """Instantiate embeddings from Elasticsearch credentials.

        Args:
            model_id (str): The model_id of the model deployed in the Elasticsearch
                cluster.
            input_field (str): The name of the key for the input text field in the
                document. Defaults to 'text_field'.
            es_cloud_id (str, optional): The Elasticsearch cloud ID to connect to.
            es_api_key (str, optional): Elasticsearch API key.

        Example:
            .. code-block:: python

                from langchain_elasticsearch.embeddings import ElasticsearchEmbeddings

                # Define the model ID and input field name (if different from default)
                model_id = "your_model_id"
                # Optional, only if different from 'text_field'
                input_field = "your_input_field"

                # Credentials can be passed in two ways. Either set the env vars
                # ES_CLOUD_ID and ES_API_KEY and they will be automatically
                # pulled in, or pass them in directly as kwargs.
                embeddings = ElasticsearchEmbeddings.from_credentials(
                    model_id,
                    input_field=input_field,
                    # es_cloud_id="foo",
                    # es_api_key="bar",
                )

                documents = [
                    "This is an example document.",
                    "Another example document to generate embeddings for.",
                ]
                embeddings.embed_documents(documents)
        """
        from elasticsearch.client import MlClient

        es_cloud_id = es_cloud_id or get_from_env("es_cloud_id", "ES_CLOUD_ID")
        es_api_key = es_api_key or get_from_env("es_api_key", "ES_API_KEY")

        # Connect to Elasticsearch
        es_connection = Elasticsearch(cloud_id=es_cloud_id, api_key=es_api_key)
        client = MlClient(es_connection)
        return cls(client, model_id, input_field=input_field)

    @classmethod
    def from_es_connection(
        cls,
        model_id: str,
        es_connection: Elasticsearch,
        input_field: str = "text_field",
    ) -> ElasticsearchEmbeddings:
        """
        Instantiate embeddings from an existing Elasticsearch connection.

        This method provides a way to create an instance of the ElasticsearchEmbeddings
        class using an existing Elasticsearch connection. The connection object is used
        to create an MlClient, which is then used to initialize the
        ElasticsearchEmbeddings instance.

        Args:
            model_id (str): The model_id of the model deployed in the Elasticsearch cluster.
            es_connection (elasticsearch.Elasticsearch): An existing Elasticsearch
                connection object.
            input_field (str, optional): The name of the key for the input text field
                in the document. Defaults to 'text_field'.

        Returns:
            ElasticsearchEmbeddings: An instance of the ElasticsearchEmbeddings class.

        Example:
            .. code-block:: python

                from elasticsearch import Elasticsearch

                from langchain_elasticsearch.embeddings import ElasticsearchEmbeddings

                # Define the model ID and input field name (if different from default)
                model_id = "your_model_id"
                # Optional, only if different from 'text_field'
                input_field = "your_input_field"

                # Create Elasticsearch connection
                es_connection = Elasticsearch(
                    hosts=["localhost:9200"], http_auth=("user", "password")
                )

                # Instantiate ElasticsearchEmbeddings using the existing connection
                embeddings = ElasticsearchEmbeddings.from_es_connection(
                    model_id,
                    es_connection,
                    input_field=input_field,
                )

                documents = [
                    "This is an example document.",
                    "Another example document to generate embeddings for.",
                ]
                embeddings.embed_documents(documents)
        """
        from elasticsearch.client import MlClient

        # Create an MlClient from the given Elasticsearch connection
        client = MlClient(es_connection)

        # Return a new instance of the ElasticsearchEmbeddings class with
        # the MlClient, model_id, and input_field
        return cls(client, model_id, input_field=input_field)

    def _embedding_func(self, texts: List[str]) -> List[List[float]]:
        """
        Generate embeddings for the given texts using the Elasticsearch model.

        Args:
            texts (List[str]): A list of text strings to generate embeddings for.

        Returns:
            List[List[float]]: A list of embeddings, one for each text in the input
                list.
        """
        response = self.client.infer_trained_model(
            model_id=self.model_id, docs=[{self.input_field: text} for text in texts]
        )

        embeddings = [doc["predicted_value"] for doc in response["inference_results"]]
        return embeddings

    def embed_documents(self, texts: List[str]) -> List[List[float]]:
        """
        Generate embeddings for a list of documents.

        Args:
            texts (List[str]): A list of document text strings to generate embeddings
                for.

        Returns:
            List[List[float]]: A list of embeddings, one for each document in the input
                list.
        """
        return self._embedding_func(texts)

    def embed_query(self, text: str) -> List[float]:
        """
        Generate an embedding for a single query text.

        Args:
            text (str): The query text to generate an embedding for.

        Returns:
            List[float]: The embedding for the input query text.
        """
        return self._embedding_func([text])[0]
1285 libs/partners/elasticsearch/langchain_elasticsearch/vectorstores.py (new file)
File diff suppressed because it is too large. Load Diff
1655 libs/partners/elasticsearch/poetry.lock (generated, new file)
File diff suppressed because it is too large. Load Diff
96 libs/partners/elasticsearch/pyproject.toml (new file)
@@ -0,0 +1,96 @@
[tool.poetry]
name = "langchain-elasticsearch"
version = "0.1.0"
description = "An integration package connecting Elasticsearch and LangChain"
authors = []
readme = "README.md"
repository = "https://github.com/langchain-ai/langchain"
license = "MIT"

[tool.poetry.urls]
"Source Code" = "https://github.com/langchain-ai/langchain/tree/master/libs/partners/elasticsearch"

[tool.poetry.dependencies]
python = ">=3.8.1,<4.0"
langchain-core = "^0.1"
elasticsearch = "^8.12.0"
numpy = "^1"

[tool.poetry.group.test]
optional = true

[tool.poetry.group.test.dependencies]
pytest = "^7.3.0"
freezegun = "^1.2.2"
pytest-mock = "^3.10.0"
syrupy = "^4.0.2"
pytest-watcher = "^0.3.4"
pytest-asyncio = "^0.21.1"
langchain = { path = "../../langchain", develop = true }
langchain-community = { path = "../../community", develop = true }
langchain-core = { path = "../../core", develop = true }

[tool.poetry.group.codespell]
optional = true

[tool.poetry.group.codespell.dependencies]
codespell = "^2.2.0"

[tool.poetry.group.lint]
optional = true

[tool.poetry.group.lint.dependencies]
ruff = "^0.1.5"

[tool.poetry.group.typing.dependencies]
mypy = "^0.991"
langchain-core = { path = "../../core", develop = true }

[tool.poetry.group.dev]
optional = true

[tool.poetry.group.dev.dependencies]
langchain-core = { path = "../../core", develop = true }

[tool.poetry.group.test_integration]
optional = true

[tool.poetry.group.test_integration.dependencies]

[tool.ruff]
select = [
  "E",  # pycodestyle
  "F",  # pyflakes
  "I",  # isort
]

[tool.mypy]
disallow_untyped_defs = "True"

[tool.coverage.run]
omit = ["tests/*"]

[build-system]
requires = ["poetry-core>=1.0.0"]
build-backend = "poetry.core.masonry.api"

[tool.pytest.ini_options]
# --strict-markers will raise errors on unknown marks.
# https://docs.pytest.org/en/7.1.x/how-to/mark.html#raising-errors-on-unknown-marks
#
# https://docs.pytest.org/en/7.1.x/reference/reference.html
# --strict-config any warnings encountered while parsing the `pytest`
# section of the configuration file raise errors.
#
# https://github.com/tophat/syrupy
# --snapshot-warn-unused Prints a warning on unused snapshots rather than fail the test suite.
addopts = "--snapshot-warn-unused --strict-markers --strict-config --durations=5"
# Registering custom markers.
# https://docs.pytest.org/en/7.1.x/example/markers.html#registering-markers
markers = [
  "requires: mark tests as requiring a specific library",
  "asyncio: mark tests as requiring asyncio",
  "compile: mark placeholder test used to compile integration tests without running them",
]
asyncio_mode = "auto"
17 libs/partners/elasticsearch/scripts/check_imports.py (new file)
@ -0,0 +1,17 @@
|
|||||||
|
import sys
|
||||||
|
import traceback
|
||||||
|
from importlib.machinery import SourceFileLoader
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
files = sys.argv[1:]
|
||||||
|
has_failure = False
|
||||||
|
for file in files:
|
||||||
|
try:
|
||||||
|
SourceFileLoader("x", file).load_module()
|
||||||
|
except Exception:
|
||||||
|
has_failure = True
|
||||||
|
print(file)
|
||||||
|
traceback.print_exc()
|
||||||
|
print()
|
||||||
|
|
||||||
|
sys.exit(1 if has_failure else 0)
|
27 libs/partners/elasticsearch/scripts/check_pydantic.sh (new executable file)
@ -0,0 +1,27 @@
|
|||||||
|
#!/bin/bash
|
||||||
|
#
|
||||||
|
# This script searches for lines starting with "import pydantic" or "from pydantic"
|
||||||
|
# in tracked files within a Git repository.
|
||||||
|
#
|
||||||
|
# Usage: ./scripts/check_pydantic.sh /path/to/repository
|
||||||
|
|
||||||
|
# Check if a path argument is provided
|
||||||
|
if [ $# -ne 1 ]; then
|
||||||
|
echo "Usage: $0 /path/to/repository"
|
||||||
|
exit 1
|
||||||
|
fi
|
||||||
|
|
||||||
|
repository_path="$1"
|
||||||
|
|
||||||
|
# Search for lines matching the pattern within the specified repository
|
||||||
|
result=$(git -C "$repository_path" grep -E '^import pydantic|^from pydantic')
|
||||||
|
|
||||||
|
# Check if any matching lines were found
|
||||||
|
if [ -n "$result" ]; then
|
||||||
|
echo "ERROR: The following lines need to be updated:"
|
||||||
|
echo "$result"
|
||||||
|
echo "Please replace the code with an import from langchain_core.pydantic_v1."
|
||||||
|
echo "For example, replace 'from pydantic import BaseModel'"
|
||||||
|
echo "with 'from langchain_core.pydantic_v1 import BaseModel'"
|
||||||
|
exit 1
|
||||||
|
fi
|
17 libs/partners/elasticsearch/scripts/lint_imports.sh (new executable file)
@ -0,0 +1,17 @@
|
|||||||
|
#!/bin/bash
|
||||||
|
|
||||||
|
set -eu
|
||||||
|
|
||||||
|
# Initialize a variable to keep track of errors
|
||||||
|
errors=0
|
||||||
|
|
||||||
|
# make sure not importing from langchain or langchain_experimental
|
||||||
|
git --no-pager grep '^from langchain\.' . && errors=$((errors+1))
|
||||||
|
git --no-pager grep '^from langchain_experimental\.' . && errors=$((errors+1))
|
||||||
|
|
||||||
|
# Decide on an exit status based on the errors
|
||||||
|
if [ "$errors" -gt 0 ]; then
|
||||||
|
exit 1
|
||||||
|
else
|
||||||
|
exit 0
|
||||||
|
fi
|
0 libs/partners/elasticsearch/tests/__init__.py (new file)
55 libs/partners/elasticsearch/tests/fake_embeddings.py (new file)
@ -0,0 +1,55 @@
|
|||||||
|
"""Fake Embedding class for testing purposes."""
|
||||||
|
|
||||||
|
from typing import List
|
||||||
|
|
||||||
|
from langchain_core.embeddings import Embeddings
|
||||||
|
|
||||||
|
fake_texts = ["foo", "bar", "baz"]
|
||||||
|
|
||||||
|
|
||||||
|
class FakeEmbeddings(Embeddings):
|
||||||
|
"""Fake embeddings functionality for testing."""
|
||||||
|
|
||||||
|
def embed_documents(self, texts: List[str]) -> List[List[float]]:
|
||||||
|
"""Return simple embeddings.
|
||||||
|
Embeddings encode each text as its index."""
|
||||||
|
return [[float(1.0)] * 9 + [float(i)] for i in range(len(texts))]
|
||||||
|
|
||||||
|
async def aembed_documents(self, texts: List[str]) -> List[List[float]]:
|
||||||
|
return self.embed_documents(texts)
|
||||||
|
|
||||||
|
def embed_query(self, text: str) -> List[float]:
|
||||||
|
"""Return constant query embeddings.
|
||||||
|
Embeddings are identical to embed_documents(texts)[0].
|
||||||
|
Distance to each text will be that text's index,
|
||||||
|
as it was passed to embed_documents."""
|
||||||
|
return [float(1.0)] * 9 + [float(0.0)]
|
||||||
|
|
||||||
|
async def aembed_query(self, text: str) -> List[float]:
|
||||||
|
return self.embed_query(text)
|
||||||
|
|
||||||
|
|
||||||
|
class ConsistentFakeEmbeddings(FakeEmbeddings):
|
||||||
|
"""Fake embeddings which remember all the texts seen so far to return consistent
|
||||||
|
vectors for the same texts."""
|
||||||
|
|
||||||
|
def __init__(self, dimensionality: int = 10) -> None:
|
||||||
|
self.known_texts: List[str] = []
|
||||||
|
self.dimensionality = dimensionality
|
||||||
|
|
||||||
|
def embed_documents(self, texts: List[str]) -> List[List[float]]:
|
||||||
|
"""Return consistent embeddings for each text seen so far."""
|
||||||
|
out_vectors = []
|
||||||
|
for text in texts:
|
||||||
|
if text not in self.known_texts:
|
||||||
|
self.known_texts.append(text)
|
||||||
|
vector = [float(1.0)] * (self.dimensionality - 1) + [
|
||||||
|
float(self.known_texts.index(text))
|
||||||
|
]
|
||||||
|
out_vectors.append(vector)
|
||||||
|
return out_vectors
|
||||||
|
|
||||||
|
def embed_query(self, text: str) -> List[float]:
|
||||||
|
"""Return consistent embeddings for the text, if seen before, or a constant
|
||||||
|
one if the text is unknown."""
|
||||||
|
return self.embed_documents([text])[0]
|
@ -0,0 +1,89 @@
|
|||||||
|
import json
|
||||||
|
import os
|
||||||
|
import uuid
|
||||||
|
from typing import Generator, Union
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
from langchain.memory import ConversationBufferMemory
|
||||||
|
from langchain_core.messages import message_to_dict
|
||||||
|
|
||||||
|
from langchain_elasticsearch.chat_history import ElasticsearchChatMessageHistory
|
||||||
|
|
||||||
|
"""
|
||||||
|
cd tests/integration_tests/memory/docker-compose
|
||||||
|
docker-compose -f elasticsearch.yml up
|
||||||
|
|
||||||
|
By default runs against local docker instance of Elasticsearch.
|
||||||
|
To run against Elastic Cloud, set the following environment variables:
|
||||||
|
- ES_CLOUD_ID
|
||||||
|
- ES_USERNAME
|
||||||
|
- ES_PASSWORD
|
||||||
|
"""
|
||||||
|
|
||||||
|
|
||||||
|
class TestElasticsearch:
|
||||||
|
@pytest.fixture(scope="class", autouse=True)
|
||||||
|
def elasticsearch_connection(self) -> Union[dict, Generator[dict, None, None]]:
|
||||||
|
# Run this integration test against Elasticsearch on localhost,
|
||||||
|
# or an Elastic Cloud instance
|
||||||
|
from elasticsearch import Elasticsearch
|
||||||
|
|
||||||
|
es_url = os.environ.get("ES_URL", "http://localhost:9200")
|
||||||
|
es_cloud_id = os.environ.get("ES_CLOUD_ID")
|
||||||
|
es_api_key = os.environ.get("ES_API_KEY")
|
||||||
|
|
||||||
|
if es_cloud_id:
|
||||||
|
es = Elasticsearch(
|
||||||
|
cloud_id=es_cloud_id,
|
||||||
|
api_key=es_api_key,
|
||||||
|
)
|
||||||
|
yield {
|
||||||
|
"es_cloud_id": es_cloud_id,
|
||||||
|
"es_api_key": es_api_key,
|
||||||
|
}
|
||||||
|
|
||||||
|
else:
|
||||||
|
# Running this integration test with local docker instance
|
||||||
|
es = Elasticsearch(hosts=es_url)
|
||||||
|
yield {"es_url": es_url}
|
||||||
|
|
||||||
|
# Clear all indexes
|
||||||
|
index_names = es.indices.get(index="_all").keys()
|
||||||
|
for index_name in index_names:
|
||||||
|
if index_name.startswith("test_"):
|
||||||
|
es.indices.delete(index=index_name)
|
||||||
|
es.indices.refresh(index="_all")
|
||||||
|
|
||||||
|
@pytest.fixture(scope="function")
|
||||||
|
def index_name(self) -> str:
|
||||||
|
"""Return the index name."""
|
||||||
|
return f"test_{uuid.uuid4().hex}"
|
||||||
|
|
||||||
|
def test_memory_with_message_store(
|
||||||
|
self, elasticsearch_connection: dict, index_name: str
|
||||||
|
) -> None:
|
||||||
|
"""Test the memory with a message store."""
|
||||||
|
# setup Elasticsearch as a message store
|
||||||
|
message_history = ElasticsearchChatMessageHistory(
|
||||||
|
**elasticsearch_connection, index=index_name, session_id="test-session"
|
||||||
|
)
|
||||||
|
|
||||||
|
memory = ConversationBufferMemory(
|
||||||
|
memory_key="baz", chat_memory=message_history, return_messages=True
|
||||||
|
)
|
||||||
|
|
||||||
|
# add some messages
|
||||||
|
memory.chat_memory.add_ai_message("This is me, the AI")
|
||||||
|
memory.chat_memory.add_user_message("This is me, the human")
|
||||||
|
|
||||||
|
# get the message history from the memory store and turn it into a json
|
||||||
|
messages = memory.chat_memory.messages
|
||||||
|
messages_json = json.dumps([message_to_dict(msg) for msg in messages])
|
||||||
|
|
||||||
|
assert "This is me, the AI" in messages_json
|
||||||
|
assert "This is me, the human" in messages_json
|
||||||
|
|
||||||
|
# remove the record from Elasticsearch, so the next test run won't pick it up
|
||||||
|
memory.chat_memory.clear()
|
||||||
|
|
||||||
|
assert memory.chat_memory.messages == []
|
@ -0,0 +1,7 @@
|
|||||||
|
import pytest
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.compile
|
||||||
|
def test_placeholder() -> None:
|
||||||
|
"""Used for compiling integration tests without running any real tests."""
|
||||||
|
pass
|
@ -0,0 +1,48 @@
|
|||||||
|
"""Test elasticsearch_embeddings embeddings."""
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
from langchain_core.utils import get_from_env
|
||||||
|
|
||||||
|
from langchain_elasticsearch.embeddings import ElasticsearchEmbeddings
|
||||||
|
|
||||||
|
# deployed with
|
||||||
|
# https://www.elastic.co/guide/en/machine-learning/current/ml-nlp-text-emb-vector-search-example.html
|
||||||
|
DEFAULT_MODEL = "sentence-transformers__msmarco-minilm-l-12-v3"
|
||||||
|
DEFAULT_NUM_DIMENSIONS = "384"
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def model_id() -> str:
|
||||||
|
return get_from_env("model_id", "MODEL_ID", DEFAULT_MODEL)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def expected_num_dimensions() -> int:
|
||||||
|
return int(
|
||||||
|
get_from_env(
|
||||||
|
"expected_num_dimensions", "EXPECTED_NUM_DIMENSIONS", DEFAULT_NUM_DIMENSIONS
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def test_elasticsearch_embedding_documents(
|
||||||
|
model_id: str, expected_num_dimensions: int
|
||||||
|
) -> None:
|
||||||
|
"""Test Elasticsearch embedding documents."""
|
||||||
|
documents = ["foo bar", "bar foo", "foo"]
|
||||||
|
embedding = ElasticsearchEmbeddings.from_credentials(model_id)
|
||||||
|
output = embedding.embed_documents(documents)
|
||||||
|
assert len(output) == 3
|
||||||
|
assert len(output[0]) == expected_num_dimensions
|
||||||
|
assert len(output[1]) == expected_num_dimensions
|
||||||
|
assert len(output[2]) == expected_num_dimensions
|
||||||
|
|
||||||
|
|
||||||
|
def test_elasticsearch_embedding_query(
|
||||||
|
model_id: str, expected_num_dimensions: int
|
||||||
|
) -> None:
|
||||||
|
"""Test Elasticsearch embedding query."""
|
||||||
|
document = "foo bar"
|
||||||
|
embedding = ElasticsearchEmbeddings.from_credentials(model_id)
|
||||||
|
output = embedding.embed_query(document)
|
||||||
|
assert len(output) == expected_num_dimensions
|
@ -0,0 +1,931 @@
|
|||||||
|
"""Test ElasticsearchStore functionality."""
|
||||||
|
|
||||||
|
import logging
|
||||||
|
import os
|
||||||
|
import re
|
||||||
|
import uuid
|
||||||
|
from typing import Any, Dict, Generator, List, Union
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
from elastic_transport import Transport
|
||||||
|
from elasticsearch import Elasticsearch
|
||||||
|
from elasticsearch.helpers import BulkIndexError
|
||||||
|
from langchain_core.documents import Document
|
||||||
|
|
||||||
|
from langchain_elasticsearch.vectorstores import ElasticsearchStore
|
||||||
|
|
||||||
|
from ..fake_embeddings import (
|
||||||
|
ConsistentFakeEmbeddings,
|
||||||
|
FakeEmbeddings,
|
||||||
|
)
|
||||||
|
|
||||||
|
logging.basicConfig(level=logging.DEBUG)
|
||||||
|
|
||||||
|
"""
|
||||||
|
cd tests/integration_tests/vectorstores/docker-compose
|
||||||
|
docker-compose -f elasticsearch.yml up
|
||||||
|
|
||||||
|
By default runs against local docker instance of Elasticsearch.
|
||||||
|
To run against Elastic Cloud, set the following environment variables:
|
||||||
|
- ES_CLOUD_ID
|
||||||
|
- ES_API_KEY
|
||||||
|
|
||||||
|
Some of the tests require the following models to be deployed in the ML Node:
|
||||||
|
- elser (can be downloaded and deployed through Kibana and trained models UI)
|
||||||
|
- sentence-transformers__all-minilm-l6-v2 (can be deployed
|
||||||
|
through API, loaded via eland)
|
||||||
|
|
||||||
|
These tests that require the models to be deployed are skipped by default.
|
||||||
|
Enable them by adding the model name to the modelsDeployed list below.
|
||||||
|
"""
|
||||||
|
|
||||||
|
modelsDeployed: List[str] = [
|
||||||
|
# "elser",
|
||||||
|
# "sentence-transformers__all-minilm-l6-v2",
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
class TestElasticsearch:
|
||||||
|
@classmethod
|
||||||
|
def setup_class(cls) -> None:
|
||||||
|
if not os.getenv("OPENAI_API_KEY"):
|
||||||
|
raise ValueError("OPENAI_API_KEY environment variable is not set")
|
||||||
|
|
||||||
|
@pytest.fixture(scope="class", autouse=True)
|
||||||
|
def elasticsearch_connection(self) -> Union[dict, Generator[dict, None, None]]:
|
||||||
|
es_url = os.environ.get("ES_URL", "http://localhost:9200")
|
||||||
|
cloud_id = os.environ.get("ES_CLOUD_ID")
|
||||||
|
api_key = os.environ.get("ES_API_KEY")
|
||||||
|
|
||||||
|
if cloud_id:
|
||||||
|
# Running this integration test with Elastic Cloud
|
||||||
|
# Required for in-stack inference testing (ELSER + model_id)
|
||||||
|
es = Elasticsearch(
|
||||||
|
cloud_id=cloud_id,
|
||||||
|
api_key=api_key,
|
||||||
|
)
|
||||||
|
yield {
|
||||||
|
"es_cloud_id": cloud_id,
|
||||||
|
"es_api_key": api_key,
|
||||||
|
}
|
||||||
|
|
||||||
|
else:
|
||||||
|
# Running this integration test with local docker instance
|
||||||
|
        es = Elasticsearch(hosts=es_url)
        yield {"es_url": es_url}

        # Clear all indexes
        index_names = es.indices.get(index="_all").keys()
        for index_name in index_names:
            if index_name.startswith("test_"):
                es.indices.delete(index=index_name)
        es.indices.refresh(index="_all")

        # clear all test pipelines
        try:
            response = es.ingest.get_pipeline(id="test_*,*_sparse_embedding")

            for pipeline_id, _ in response.items():
                try:
                    es.ingest.delete_pipeline(id=pipeline_id)
                    print(f"Deleted pipeline: {pipeline_id}")  # noqa: T201
                except Exception as e:
                    print(f"Pipeline error: {e}")  # noqa: T201
        except Exception:
            pass

    @pytest.fixture(scope="function")
    def es_client(self) -> Any:
        class CustomTransport(Transport):
            requests = []

            def perform_request(self, *args, **kwargs):  # type: ignore
                self.requests.append(kwargs)
                return super().perform_request(*args, **kwargs)

        es_url = os.environ.get("ES_URL", "http://localhost:9200")
        cloud_id = os.environ.get("ES_CLOUD_ID")
        api_key = os.environ.get("ES_API_KEY")

        if cloud_id:
            # Running this integration test with Elastic Cloud
            # Required for in-stack inference testing (ELSER + model_id)
            es = Elasticsearch(
                cloud_id=cloud_id,
                api_key=api_key,
                transport_class=CustomTransport,
            )
            return es
        else:
            # Running this integration test with local docker instance
            es = Elasticsearch(hosts=es_url, transport_class=CustomTransport)
            return es

    @pytest.fixture(scope="function")
    def index_name(self) -> str:
        """Return the index name."""
        return f"test_{uuid.uuid4().hex}"

    def test_similarity_search_without_metadata(
        self, elasticsearch_connection: dict, index_name: str
    ) -> None:
        """Test end to end construction and search without metadata."""

        def assert_query(query_body: dict, query: str) -> dict:
            assert query_body == {
                "knn": {
                    "field": "vector",
                    "filter": [],
                    "k": 1,
                    "num_candidates": 50,
                    "query_vector": [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0],
                }
            }
            return query_body

        texts = ["foo", "bar", "baz"]
        docsearch = ElasticsearchStore.from_texts(
            texts,
            FakeEmbeddings(),
            **elasticsearch_connection,
            index_name=index_name,
        )
        output = docsearch.similarity_search("foo", k=1, custom_query=assert_query)
        assert output == [Document(page_content="foo")]

    async def test_similarity_search_without_metadata_async(
        self, elasticsearch_connection: dict, index_name: str
    ) -> None:
        """Test end to end construction and search without metadata."""
        texts = ["foo", "bar", "baz"]
        docsearch = ElasticsearchStore.from_texts(
            texts,
            FakeEmbeddings(),
            **elasticsearch_connection,
            index_name=index_name,
        )
        output = await docsearch.asimilarity_search("foo", k=1)
        assert output == [Document(page_content="foo")]

    def test_add_embeddings(
        self, elasticsearch_connection: dict, index_name: str
    ) -> None:
        """
        Test add_embeddings, which accepts pre-built embeddings instead of
        using inference for the texts.
        This allows you to separate the embeddings text and the page_content
        for better proximity between user's question and embedded text.
        For example, your embedding text can be a question, whereas page_content
        is the answer.
        """
        embeddings = ConsistentFakeEmbeddings()
        text_input = ["foo1", "foo2", "foo3"]
        metadatas = [{"page": i} for i in range(len(text_input))]

        """In real use case, embedding_input can be questions for each text"""
        embedding_input = ["foo2", "foo3", "foo1"]
        embedding_vectors = embeddings.embed_documents(embedding_input)

        docsearch = ElasticsearchStore._create_cls_from_kwargs(
            embeddings,
            **elasticsearch_connection,
            index_name=index_name,
        )
        docsearch.add_embeddings(list(zip(text_input, embedding_vectors)), metadatas)
        output = docsearch.similarity_search("foo1", k=1)
        assert output == [Document(page_content="foo3", metadata={"page": 2})]

    def test_similarity_search_with_metadata(
        self, elasticsearch_connection: dict, index_name: str
    ) -> None:
        """Test end to end construction and search with metadata."""
        texts = ["foo", "bar", "baz"]
        metadatas = [{"page": i} for i in range(len(texts))]
        docsearch = ElasticsearchStore.from_texts(
            texts,
            ConsistentFakeEmbeddings(),
            metadatas=metadatas,
            **elasticsearch_connection,
            index_name=index_name,
        )

        output = docsearch.similarity_search("foo", k=1)
        assert output == [Document(page_content="foo", metadata={"page": 0})]

        output = docsearch.similarity_search("bar", k=1)
        assert output == [Document(page_content="bar", metadata={"page": 1})]

    def test_similarity_search_with_filter(
        self, elasticsearch_connection: dict, index_name: str
    ) -> None:
        """Test end to end construction and search with metadata."""
        texts = ["foo", "foo", "foo"]
        metadatas = [{"page": i} for i in range(len(texts))]
        docsearch = ElasticsearchStore.from_texts(
            texts,
            FakeEmbeddings(),
            metadatas=metadatas,
            **elasticsearch_connection,
            index_name=index_name,
        )

        def assert_query(query_body: dict, query: str) -> dict:
            assert query_body == {
                "knn": {
                    "field": "vector",
                    "filter": [{"term": {"metadata.page": "1"}}],
                    "k": 3,
                    "num_candidates": 50,
                    "query_vector": [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0],
                }
            }
            return query_body

        output = docsearch.similarity_search(
            query="foo",
            k=3,
            filter=[{"term": {"metadata.page": "1"}}],
            custom_query=assert_query,
        )
        assert output == [Document(page_content="foo", metadata={"page": 1})]

    def test_similarity_search_with_doc_builder(
        self, elasticsearch_connection: dict, index_name: str
    ) -> None:
        texts = ["foo", "foo", "foo"]
        metadatas = [{"page": i} for i in range(len(texts))]
        docsearch = ElasticsearchStore.from_texts(
            texts,
            FakeEmbeddings(),
            metadatas=metadatas,
            **elasticsearch_connection,
            index_name=index_name,
        )

        def custom_document_builder(_: Dict) -> Document:
            return Document(
                page_content="Mock content!",
                metadata={
                    "page_number": -1,
                    "original_filename": "Mock filename!",
                },
            )

        output = docsearch.similarity_search(
            query="foo", k=1, doc_builder=custom_document_builder
        )
        assert output[0].page_content == "Mock content!"
        assert output[0].metadata["page_number"] == -1
        assert output[0].metadata["original_filename"] == "Mock filename!"

    def test_similarity_search_exact_search(
        self, elasticsearch_connection: dict, index_name: str
    ) -> None:
        """Test end to end construction and search with metadata."""
        texts = ["foo", "bar", "baz"]
        docsearch = ElasticsearchStore.from_texts(
            texts,
            FakeEmbeddings(),
            **elasticsearch_connection,
            index_name=index_name,
            strategy=ElasticsearchStore.ExactRetrievalStrategy(),
        )

        expected_query = {
            "query": {
                "script_score": {
                    "query": {"match_all": {}},
                    "script": {
                        "source": "cosineSimilarity(params.query_vector, 'vector') + 1.0",  # noqa: E501
                        "params": {
                            "query_vector": [
                                1.0,
                                1.0,
                                1.0,
                                1.0,
                                1.0,
                                1.0,
                                1.0,
                                1.0,
                                1.0,
                                0.0,
                            ]
                        },
                    },
                }
            }
        }

        def assert_query(query_body: dict, query: str) -> dict:
            assert query_body == expected_query
            return query_body

        output = docsearch.similarity_search("foo", k=1, custom_query=assert_query)
        assert output == [Document(page_content="foo")]

    def test_similarity_search_exact_search_with_filter(
        self, elasticsearch_connection: dict, index_name: str
    ) -> None:
        """Test end to end construction and search with metadata."""
        texts = ["foo", "bar", "baz"]
        metadatas = [{"page": i} for i in range(len(texts))]
        docsearch = ElasticsearchStore.from_texts(
            texts,
            FakeEmbeddings(),
            **elasticsearch_connection,
            index_name=index_name,
            metadatas=metadatas,
            strategy=ElasticsearchStore.ExactRetrievalStrategy(),
        )

        def assert_query(query_body: dict, query: str) -> dict:
            expected_query = {
                "query": {
                    "script_score": {
                        "query": {"bool": {"filter": [{"term": {"metadata.page": 0}}]}},
                        "script": {
                            "source": "cosineSimilarity(params.query_vector, 'vector') + 1.0",  # noqa: E501
                            "params": {
                                "query_vector": [
                                    1.0,
                                    1.0,
                                    1.0,
                                    1.0,
                                    1.0,
                                    1.0,
                                    1.0,
                                    1.0,
                                    1.0,
                                    0.0,
                                ]
                            },
                        },
                    }
                }
            }
            assert query_body == expected_query
            return query_body

        output = docsearch.similarity_search(
            "foo",
            k=1,
            custom_query=assert_query,
            filter=[{"term": {"metadata.page": 0}}],
        )
        assert output == [Document(page_content="foo", metadata={"page": 0})]

    def test_similarity_search_exact_search_distance_dot_product(
        self, elasticsearch_connection: dict, index_name: str
    ) -> None:
        """Test end to end construction and search with metadata."""
        texts = ["foo", "bar", "baz"]
        docsearch = ElasticsearchStore.from_texts(
            texts,
            FakeEmbeddings(),
            **elasticsearch_connection,
            index_name=index_name,
            strategy=ElasticsearchStore.ExactRetrievalStrategy(),
            distance_strategy="DOT_PRODUCT",
        )

        def assert_query(query_body: dict, query: str) -> dict:
            assert query_body == {
                "query": {
                    "script_score": {
                        "query": {"match_all": {}},
                        "script": {
                            "source": """
            double value = dotProduct(params.query_vector, 'vector');
            return sigmoid(1, Math.E, -value);
            """,
                            "params": {
                                "query_vector": [
                                    1.0,
                                    1.0,
                                    1.0,
                                    1.0,
                                    1.0,
                                    1.0,
                                    1.0,
                                    1.0,
                                    1.0,
                                    0.0,
                                ]
                            },
                        },
                    }
                }
            }
            return query_body

        output = docsearch.similarity_search("foo", k=1, custom_query=assert_query)
        assert output == [Document(page_content="foo")]

    def test_similarity_search_exact_search_unknown_distance_strategy(
        self, elasticsearch_connection: dict, index_name: str
    ) -> None:
        """Test end to end construction and search with unknown distance strategy."""

        with pytest.raises(KeyError):
            texts = ["foo", "bar", "baz"]
            ElasticsearchStore.from_texts(
                texts,
                FakeEmbeddings(),
                **elasticsearch_connection,
                index_name=index_name,
                strategy=ElasticsearchStore.ExactRetrievalStrategy(),
                distance_strategy="NOT_A_STRATEGY",
            )

    def test_max_marginal_relevance_search(
        self, elasticsearch_connection: dict, index_name: str
    ) -> None:
        """Test max marginal relevance search."""
        texts = ["foo", "bar", "baz"]
        docsearch = ElasticsearchStore.from_texts(
            texts,
            FakeEmbeddings(),
            **elasticsearch_connection,
            index_name=index_name,
            strategy=ElasticsearchStore.ExactRetrievalStrategy(),
        )

        mmr_output = docsearch.max_marginal_relevance_search(texts[0], k=3, fetch_k=3)
        sim_output = docsearch.similarity_search(texts[0], k=3)
        assert mmr_output == sim_output

        mmr_output = docsearch.max_marginal_relevance_search(texts[0], k=2, fetch_k=3)
        assert len(mmr_output) == 2
        assert mmr_output[0].page_content == texts[0]
        assert mmr_output[1].page_content == texts[1]

        mmr_output = docsearch.max_marginal_relevance_search(
            texts[0],
            k=2,
            fetch_k=3,
            lambda_mult=0.1,  # more diversity
        )
        assert len(mmr_output) == 2
        assert mmr_output[0].page_content == texts[0]
        assert mmr_output[1].page_content == texts[2]

        # if fetch_k < k, then the output will be less than k
        mmr_output = docsearch.max_marginal_relevance_search(texts[0], k=3, fetch_k=2)
        assert len(mmr_output) == 2

    def test_similarity_search_approx_with_hybrid_search(
        self, elasticsearch_connection: dict, index_name: str
    ) -> None:
        """Test end to end construction and search with metadata."""
        texts = ["foo", "bar", "baz"]
        docsearch = ElasticsearchStore.from_texts(
            texts,
            FakeEmbeddings(),
            **elasticsearch_connection,
            index_name=index_name,
            strategy=ElasticsearchStore.ApproxRetrievalStrategy(hybrid=True),
        )

        def assert_query(query_body: dict, query: str) -> dict:
            assert query_body == {
                "knn": {
                    "field": "vector",
                    "filter": [],
                    "k": 1,
                    "num_candidates": 50,
                    "query_vector": [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0],
                },
                "query": {
                    "bool": {
                        "filter": [],
                        "must": [{"match": {"text": {"query": "foo"}}}],
                    }
                },
                "rank": {"rrf": {}},
            }
            return query_body

        output = docsearch.similarity_search("foo", k=1, custom_query=assert_query)
        assert output == [Document(page_content="foo")]

    def test_similarity_search_approx_with_hybrid_search_rrf(
        self, es_client: Any, elasticsearch_connection: dict, index_name: str
    ) -> None:
        """Test end to end construction and rrf hybrid search with metadata."""
        from functools import partial
        from typing import Optional

        # 1. check query_body is okay
        rrf_test_cases: List[Optional[Union[dict, bool]]] = [
            True,
            False,
            {"rank_constant": 1, "window_size": 5},
        ]
        for rrf_test_case in rrf_test_cases:
            texts = ["foo", "bar", "baz"]
            docsearch = ElasticsearchStore.from_texts(
                texts,
                FakeEmbeddings(),
                **elasticsearch_connection,
                index_name=index_name,
                strategy=ElasticsearchStore.ApproxRetrievalStrategy(
                    hybrid=True, rrf=rrf_test_case
                ),
            )

            def assert_query(
                query_body: dict,
                query: str,
                rrf: Optional[Union[dict, bool]] = True,
            ) -> dict:
                cmp_query_body = {
                    "knn": {
                        "field": "vector",
                        "filter": [],
                        "k": 3,
                        "num_candidates": 50,
                        "query_vector": [
                            1.0,
                            1.0,
                            1.0,
                            1.0,
                            1.0,
                            1.0,
                            1.0,
                            1.0,
                            1.0,
                            0.0,
                        ],
                    },
                    "query": {
                        "bool": {
                            "filter": [],
                            "must": [{"match": {"text": {"query": "foo"}}}],
                        }
                    },
                }

                if isinstance(rrf, dict):
                    cmp_query_body["rank"] = {"rrf": rrf}
                elif isinstance(rrf, bool) and rrf is True:
                    cmp_query_body["rank"] = {"rrf": {}}

                assert query_body == cmp_query_body

                return query_body

            ## without fetch_k parameter
            output = docsearch.similarity_search(
                "foo", k=3, custom_query=partial(assert_query, rrf=rrf_test_case)
            )

            # 2. check query result is okay
            es_output = es_client.search(
                index=index_name,
                query={
                    "bool": {
                        "filter": [],
                        "must": [{"match": {"text": {"query": "foo"}}}],
                    }
                },
                knn={
                    "field": "vector",
                    "filter": [],
                    "k": 3,
                    "num_candidates": 50,
                    "query_vector": [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0],
                },
                size=3,
                rank={"rrf": {"rank_constant": 1, "window_size": 5}},
            )

            assert [o.page_content for o in output] == [
                e["_source"]["text"] for e in es_output["hits"]["hits"]
            ]

        # 3. check rrf default option is okay
        docsearch = ElasticsearchStore.from_texts(
            texts,
            FakeEmbeddings(),
            **elasticsearch_connection,
            index_name=index_name,
            strategy=ElasticsearchStore.ApproxRetrievalStrategy(hybrid=True),
        )

        ## with fetch_k parameter
        output = docsearch.similarity_search(
            "foo", k=3, fetch_k=50, custom_query=assert_query
        )

    def test_similarity_search_approx_with_custom_query_fn(
        self, elasticsearch_connection: dict, index_name: str
    ) -> None:
        """test that custom query function is called
        with the query string and query body"""

        def my_custom_query(query_body: dict, query: str) -> dict:
            assert query == "foo"
            assert query_body == {
                "knn": {
                    "field": "vector",
                    "filter": [],
                    "k": 1,
                    "num_candidates": 50,
                    "query_vector": [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0],
                }
            }
            return {"query": {"match": {"text": {"query": "bar"}}}}

        """Test end to end construction and search with metadata."""
        texts = ["foo", "bar", "baz"]
        docsearch = ElasticsearchStore.from_texts(
            texts, FakeEmbeddings(), **elasticsearch_connection, index_name=index_name
        )
        output = docsearch.similarity_search("foo", k=1, custom_query=my_custom_query)
        assert output == [Document(page_content="bar")]

    @pytest.mark.skipif(
        "sentence-transformers__all-minilm-l6-v2" not in modelsDeployed,
        reason="Sentence Transformers model not deployed in ML Node, skipping test",
    )
    def test_similarity_search_with_approx_infer_instack(
        self, elasticsearch_connection: dict, index_name: str
    ) -> None:
        """test end to end with approx retrieval strategy and inference in-stack"""
        docsearch = ElasticsearchStore(
            index_name=index_name,
            strategy=ElasticsearchStore.ApproxRetrievalStrategy(
                query_model_id="sentence-transformers__all-minilm-l6-v2"
            ),
            query_field="text_field",
            vector_query_field="vector_query_field.predicted_value",
            **elasticsearch_connection,
        )

        # setting up the pipeline for inference
        docsearch.client.ingest.put_pipeline(
            id="test_pipeline",
            processors=[
                {
                    "inference": {
                        "model_id": "sentence-transformers__all-minilm-l6-v2",
                        "field_map": {"query_field": "text_field"},
                        "target_field": "vector_query_field",
                    }
                }
            ],
        )

        # creating a new index with the pipeline,
        # not relying on langchain to create the index
        docsearch.client.indices.create(
            index=index_name,
            mappings={
                "properties": {
                    "text_field": {"type": "text"},
                    "vector_query_field": {
                        "properties": {
                            "predicted_value": {
                                "type": "dense_vector",
                                "dims": 384,
                                "index": True,
                                "similarity": "l2_norm",
                            }
                        }
                    },
                }
            },
            settings={"index": {"default_pipeline": "test_pipeline"}},
        )

        # adding documents to the index
        texts = ["foo", "bar", "baz"]

        for i, text in enumerate(texts):
            docsearch.client.create(
                index=index_name,
                id=str(i),
                document={"text_field": text, "metadata": {}},
            )

        docsearch.client.indices.refresh(index=index_name)

        def assert_query(query_body: dict, query: str) -> dict:
            assert query_body == {
                "knn": {
                    "filter": [],
                    "field": "vector_query_field.predicted_value",
                    "k": 1,
                    "num_candidates": 50,
                    "query_vector_builder": {
                        "text_embedding": {
                            "model_id": "sentence-transformers__all-minilm-l6-v2",
                            "model_text": "foo",
                        }
                    },
                }
            }
            return query_body

        output = docsearch.similarity_search("foo", k=1, custom_query=assert_query)
        assert output == [Document(page_content="foo")]

        output = docsearch.similarity_search("bar", k=1)
        assert output == [Document(page_content="bar")]

    @pytest.mark.skipif(
        "elser" not in modelsDeployed,
        reason="ELSER not deployed in ML Node, skipping test",
    )
    def test_similarity_search_with_sparse_infer_instack(
        self, elasticsearch_connection: dict, index_name: str
    ) -> None:
        """test end to end with sparse retrieval strategy and inference in-stack"""
        texts = ["foo", "bar", "baz"]
        docsearch = ElasticsearchStore.from_texts(
            texts,
            **elasticsearch_connection,
            index_name=index_name,
            strategy=ElasticsearchStore.SparseVectorRetrievalStrategy(),
        )
        output = docsearch.similarity_search("foo", k=1)
        assert output == [Document(page_content="foo")]

    def test_elasticsearch_with_relevance_score(
        self, elasticsearch_connection: dict, index_name: str
    ) -> None:
        """Test to make sure the relevance score is scaled to 0-1."""
        texts = ["foo", "bar", "baz"]
        metadatas = [{"page": str(i)} for i in range(len(texts))]
        embeddings = FakeEmbeddings()

        docsearch = ElasticsearchStore.from_texts(
            index_name=index_name,
            texts=texts,
            embedding=embeddings,
            metadatas=metadatas,
            **elasticsearch_connection,
        )

        embedded_query = embeddings.embed_query("foo")
        output = docsearch.similarity_search_by_vector_with_relevance_scores(
            embedding=embedded_query, k=1
        )
        assert output == [(Document(page_content="foo", metadata={"page": "0"}), 1.0)]

    def test_elasticsearch_with_relevance_threshold(
        self, elasticsearch_connection: dict, index_name: str
    ) -> None:
        """Test to make sure the relevance threshold is respected."""
        texts = ["foo", "bar", "baz"]
        metadatas = [{"page": str(i)} for i in range(len(texts))]
        embeddings = FakeEmbeddings()

        docsearch = ElasticsearchStore.from_texts(
            index_name=index_name,
            texts=texts,
            embedding=embeddings,
            metadatas=metadatas,
            **elasticsearch_connection,
        )

        # Find a good threshold for testing
        query_string = "foo"
        embedded_query = embeddings.embed_query(query_string)
        top3 = docsearch.similarity_search_by_vector_with_relevance_scores(
            embedding=embedded_query, k=3
        )
        similarity_of_second_ranked = top3[1][1]
        assert len(top3) == 3

        # Test threshold
        retriever = docsearch.as_retriever(
            search_type="similarity_score_threshold",
            search_kwargs={"score_threshold": similarity_of_second_ranked},
        )
        output = retriever.get_relevant_documents(query=query_string)

        assert output == [
            top3[0][0],
            top3[1][0],
            # third ranked is out
        ]

    def test_elasticsearch_delete_ids(
        self, elasticsearch_connection: dict, index_name: str
    ) -> None:
        """Test delete methods from vector store."""
        texts = ["foo", "bar", "baz", "gni"]
        metadatas = [{"page": i} for i in range(len(texts))]
        docsearch = ElasticsearchStore(
            embedding=ConsistentFakeEmbeddings(),
            **elasticsearch_connection,
            index_name=index_name,
        )

        ids = docsearch.add_texts(texts, metadatas)
        output = docsearch.similarity_search("foo", k=10)
        assert len(output) == 4

        docsearch.delete(ids[1:3])
        output = docsearch.similarity_search("foo", k=10)
        assert len(output) == 2

        docsearch.delete(["not-existing"])
        output = docsearch.similarity_search("foo", k=10)
        assert len(output) == 2

        docsearch.delete([ids[0]])
        output = docsearch.similarity_search("foo", k=10)
        assert len(output) == 1

        docsearch.delete([ids[3]])
        output = docsearch.similarity_search("gni", k=10)
        assert len(output) == 0

    def test_elasticsearch_indexing_exception_error(
        self,
        elasticsearch_connection: dict,
        index_name: str,
        caplog: pytest.LogCaptureFixture,
    ) -> None:
        """Test bulk exception logging is giving better hints."""

        docsearch = ElasticsearchStore(
            embedding=ConsistentFakeEmbeddings(),
            **elasticsearch_connection,
            index_name=index_name,
        )

        docsearch.client.indices.create(
            index=index_name,
            mappings={"properties": {}},
            settings={"index": {"default_pipeline": "not-existing-pipeline"}},
        )

        texts = ["foo"]

        with pytest.raises(BulkIndexError):
            docsearch.add_texts(texts)

        error_reason = "pipeline with id [not-existing-pipeline] does not exist"
        log_message = f"First error reason: {error_reason}"

        assert log_message in caplog.text

    def test_elasticsearch_with_user_agent(
        self, es_client: Any, index_name: str
    ) -> None:
        """Test to make sure the user-agent is set correctly."""

        texts = ["foo", "bob", "baz"]
        ElasticsearchStore.from_texts(
            texts,
            FakeEmbeddings(),
            es_connection=es_client,
            index_name=index_name,
        )

        user_agent = es_client.transport.requests[0]["headers"]["User-Agent"]
        pattern = r"^langchain-py-vs/\d+\.\d+\.\d+$"
        match = re.match(pattern, user_agent)

        assert (
            match is not None
        ), f"The string '{user_agent}' does not match the expected pattern."

    def test_elasticsearch_with_internal_user_agent(
        self, elasticsearch_connection: Dict, index_name: str
    ) -> None:
        """Test to make sure the user-agent is set correctly."""

        texts = ["foo"]
        store = ElasticsearchStore.from_texts(
            texts,
            FakeEmbeddings(),
            **elasticsearch_connection,
            index_name=index_name,
        )

        user_agent = store.client._headers["User-Agent"]
        pattern = r"^langchain-py-vs/\d+\.\d+\.\d+$"
        match = re.match(pattern, user_agent)

        assert (
            match is not None
        ), f"The string '{user_agent}' does not match the expected pattern."

    def test_bulk_args(self, es_client: Any, index_name: str) -> None:
        """Test to make sure the user-agent is set correctly."""

        texts = ["foo", "bob", "baz"]
        ElasticsearchStore.from_texts(
            texts,
            FakeEmbeddings(),
            es_connection=es_client,
            index_name=index_name,
            bulk_kwargs={"chunk_size": 1},
        )

        # 1 for index exist, 1 for index create, 3 for index docs
        assert len(es_client.transport.requests) == 5  # type: ignore
14
libs/partners/elasticsearch/tests/unit_tests/test_imports.py
Normal file
@@ -0,0 +1,14 @@
from langchain_elasticsearch import __all__

EXPECTED_ALL = [
    "ApproxRetrievalStrategy",
    "ElasticsearchChatMessageHistory",
    "ElasticsearchEmbeddings",
    "ElasticsearchStore",
    "ExactRetrievalStrategy",
    "SparseRetrievalStrategy",
]


def test_all_imports() -> None:
    assert sorted(EXPECTED_ALL) == sorted(__all__)
@@ -0,0 +1,34 @@
"""Test Elasticsearch functionality."""

import pytest

from langchain_elasticsearch.vectorstores import (
    ApproxRetrievalStrategy,
    ElasticsearchStore,
)

from ..fake_embeddings import FakeEmbeddings


@pytest.mark.requires("elasticsearch")
def test_elasticsearch_hybrid_scores_guard() -> None:
    """Ensure an error is raised when search with score in hybrid mode
    because in this case Elasticsearch does not return any score.
    """
    from elasticsearch import Elasticsearch

    query_string = "foo"
    embeddings = FakeEmbeddings()

    store = ElasticsearchStore(
        index_name="dummy_index",
        es_connection=Elasticsearch(hosts=["http://dummy-host:9200"]),
        embedding=embeddings,
        strategy=ApproxRetrievalStrategy(hybrid=True),
    )
    with pytest.raises(ValueError):
        store.similarity_search_with_score(query_string)

    embedded_query = embeddings.embed_query(query_string)
    with pytest.raises(ValueError):
        store.similarity_search_by_vector_with_relevance_scores(embedded_query)
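For anyone trying the relocated classes, here is a minimal usage sketch based on the calls exercised in the integration tests above; the endpoint, index name, and choice of `OpenAIEmbeddings` are illustrative placeholders, not values taken from this PR.

```python
from langchain_elasticsearch import ElasticsearchStore
from langchain_openai import OpenAIEmbeddings

# Placeholder connection: a local Elasticsearch reachable at an ES_URL-style address.
# An already-configured `Elasticsearch` client can be passed via `es_connection=`
# instead, as the user-agent and bulk-args tests above do.
store = ElasticsearchStore.from_texts(
    ["foo", "bar", "baz"],
    OpenAIEmbeddings(),  # any Embeddings implementation works here
    es_url="http://localhost:9200",
    index_name="langchain-demo",
)

# Plain vector similarity search over the indexed texts.
print(store.similarity_search("foo", k=1))
```

Hybrid, exact, and sparse retrieval are selected through the same `strategy=` argument the integration tests pass to `from_texts`.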