mirror of
https://github.com/hwchase17/langchain.git
synced 2025-08-31 02:11:09 +00:00
elasticsearch: check for deployed models (#18973)
When creating a new index with a retrieval strategy that expects a model to be deployed in Elasticsearch, check that a model with this ID is actually deployed before creating the index. This lowers the probability of ending up with an index that was created with a faulty model ID and can no longer be overwritten (such an index has to be deleted manually).
This commit is contained in:
parent
b82644078e
commit
6f544a6a25
@ -2,7 +2,7 @@ from enum import Enum
|
||||
from typing import List, Union
|
||||
|
||||
import numpy as np
|
||||
from elasticsearch import Elasticsearch
|
||||
from elasticsearch import BadRequestError, ConflictError, Elasticsearch, NotFoundError
|
||||
from langchain_core import __version__ as langchain_version
|
||||
|
||||
Matrix = Union[List[List[float]], List[np.ndarray], np.ndarray]
|
||||
@ -88,3 +88,21 @@ def cosine_similarity(X: Matrix, Y: Matrix) -> np.ndarray:
|
||||
similarity = np.dot(X, Y.T) / np.outer(X_norm, Y_norm)
|
||||
similarity[np.isnan(similarity) | np.isinf(similarity)] = 0.0
|
||||
return similarity
|
||||
|
||||
|
||||
def check_if_model_deployed(client: Elasticsearch, model_id: str) -> None:
    """Probe Elasticsearch to verify that ``model_id`` can serve inference.

    A single placeholder document is sent to ``client.ml.infer_trained_model``.
    The inference result itself is ignored; only the kind of error that comes
    back (if any) matters.

    Args:
        client: Elasticsearch client used to issue the inference request.
        model_id: ID of the trained model that is expected to be deployed.

    Raises:
        NotFoundError: Propagated unchanged when the client raises it, and
            also raised in place of a ``ConflictError`` (carrying over that
            error's ``meta`` and ``body``) with a message asking the caller
            to deploy the model first.
    """
    # The real input schema of the model is unknown here, so a throwaway
    # document is used purely to trigger a response from the ML endpoint.
    placeholder_doc = {"x": "y"}

    try:
        client.ml.infer_trained_model(model_id=model_id, docs=[placeholder_doc])
    except NotFoundError:
        # Surface the client's own "not found" error untouched.
        raise
    except ConflictError as conflict:
        # Report a conflict to the caller as a missing (undeployed) model.
        raise NotFoundError(
            f"model '{model_id}' not found, please deploy it first",
            meta=conflict.meta,
            body=conflict.body,
        ) from conflict
    except BadRequestError:
        # Expected outcome: the placeholder document does not match the
        # document shape the model wants, which is fine for this check.
        return
|
||||
|
@ -22,6 +22,7 @@ from langchain_core.vectorstores import VectorStore
|
||||
|
||||
from langchain_elasticsearch._utilities import (
|
||||
DistanceStrategy,
|
||||
check_if_model_deployed,
|
||||
maximal_marginal_relevance,
|
||||
with_user_agent_header,
|
||||
)
|
||||
@ -199,6 +200,12 @@ class ApproxRetrievalStrategy(BaseRetrievalStrategy):
|
||||
else:
|
||||
return {"knn": knn}
|
||||
|
||||
def before_index_setup(
    self, client: "Elasticsearch", text_field: str, vector_query_field: str
) -> None:
    """Verify the query model before the index is set up.

    When ``self.query_model_id`` is set, ``check_if_model_deployed`` is called
    with it so that a problem with the model surfaces before index setup
    proceeds. ``text_field`` and ``vector_query_field`` are not used here.
    """
    # Nothing to verify when no query model is configured.
    if not self.query_model_id:
        return

    check_if_model_deployed(client, self.query_model_id)
|
||||
|
||||
def index(
|
||||
self,
|
||||
dims_length: Union[int, None],
|
||||
@ -340,8 +347,10 @@ class SparseRetrievalStrategy(BaseRetrievalStrategy):
|
||||
def before_index_setup(
|
||||
self, client: "Elasticsearch", text_field: str, vector_query_field: str
|
||||
) -> None:
|
||||
# If model_id is provided, create a pipeline for the model
|
||||
if self.model_id:
|
||||
check_if_model_deployed(client, self.model_id)
|
||||
|
||||
# Create a pipeline for the model
|
||||
client.ingest.put_pipeline(
|
||||
id=self._get_pipeline_name(),
|
||||
description="Embedding pipeline for langchain vectorstore",
|
||||
|
@ -7,7 +7,7 @@ import uuid
|
||||
from typing import Any, Dict, Generator, List, Union
|
||||
|
||||
import pytest
|
||||
from elasticsearch import Elasticsearch
|
||||
from elasticsearch import Elasticsearch, NotFoundError
|
||||
from elasticsearch.helpers import BulkIndexError
|
||||
from langchain_core.documents import Document
|
||||
|
||||
@ -40,7 +40,7 @@ Enable them by adding the model name to the modelsDeployed list below.
|
||||
"""
|
||||
|
||||
modelsDeployed: List[str] = [
|
||||
# "elser",
|
||||
# ".elser_model_1",
|
||||
# "sentence-transformers__all-minilm-l6-v2",
|
||||
]
|
||||
|
||||
@ -709,7 +709,7 @@ class TestElasticsearch:
|
||||
assert output == [Document(page_content="bar")]
|
||||
|
||||
@pytest.mark.skipif(
|
||||
"elser" not in modelsDeployed,
|
||||
".elser_model_1" not in modelsDeployed,
|
||||
reason="ELSER not deployed in ML Node, skipping test",
|
||||
)
|
||||
def test_similarity_search_with_sparse_infer_instack(
|
||||
@ -726,6 +726,35 @@ class TestElasticsearch:
|
||||
output = docsearch.similarity_search("foo", k=1)
|
||||
assert output == [Document(page_content="foo")]
|
||||
|
||||
def test_deployed_model_check_fails_approx(
    self, elasticsearch_connection: dict, index_name: str
) -> None:
    """Approx strategy: building a store with an undeployed model must fail."""
    with pytest.raises(NotFoundError):
        # The query model ID below is expected not to match any deployed model.
        approx_strategy = ElasticsearchStore.ApproxRetrievalStrategy(
            query_model_id="non-existing model ID",
        )
        ElasticsearchStore.from_texts(
            texts=["foo", "bar", "baz"],
            embedding=ConsistentFakeEmbeddings(10),
            **elasticsearch_connection,
            index_name=index_name,
            strategy=approx_strategy,
        )
|
||||
|
||||
def test_deployed_model_check_fails_sparse(
    self, elasticsearch_connection: dict, index_name: str
) -> None:
    """Sparse strategy: building a store with an undeployed model must fail."""
    with pytest.raises(NotFoundError):
        # The model ID below is expected not to match any deployed model.
        sparse_strategy = ElasticsearchStore.SparseVectorRetrievalStrategy(
            model_id="non-existing model ID"
        )
        ElasticsearchStore.from_texts(
            texts=["foo", "bar", "baz"],
            **elasticsearch_connection,
            index_name=index_name,
            strategy=sparse_strategy,
        )
|
||||
|
||||
def test_elasticsearch_with_relevance_score(
|
||||
self, elasticsearch_connection: dict, index_name: str
|
||||
) -> None:
|
||||
|
Loading…
Reference in New Issue
Block a user