community[patch]: Added type hinting to OpenSearch clients (#27946)

Description:
* While building OpenSearchGraphVectorStore (coming soon) on top of
OpenSearchVectorSearch, I noticed that the underlying OpenSearch clients
had no type hints. This change adds them, so the client helpers are now
annotated with OpenSearch / AsyncOpenSearch instead of Any.
* Confirmed that the tests still pass with these code changes.

Note that this introduces some code duplication (the shared `_import_*`
helpers are replaced by inline try/except imports at each call site), but
I think this approach is cleaner overall.
This commit is contained in:
Eric Pinzur 2024-11-08 20:04:57 +01:00 committed by GitHub
parent 4c2392e55c
commit c421997caa
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -2,7 +2,7 @@ from __future__ import annotations
import uuid import uuid
import warnings import warnings
from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple from typing import TYPE_CHECKING, Any, Callable, Dict, Iterable, List, Optional, Tuple
import numpy as np import numpy as np
from langchain_core.documents import Document from langchain_core.documents import Document
@ -23,57 +23,18 @@ SCRIPT_SCORING_SEARCH = "script_scoring"
PAINLESS_SCRIPTING_SEARCH = "painless_scripting" PAINLESS_SCRIPTING_SEARCH = "painless_scripting"
MATCH_ALL_QUERY = {"match_all": {}} # type: Dict MATCH_ALL_QUERY = {"match_all": {}} # type: Dict
if TYPE_CHECKING:
def _import_opensearch() -> Any: from opensearchpy import AsyncOpenSearch, OpenSearch
"""Import OpenSearch if available, otherwise raise error."""
try:
from opensearchpy import OpenSearch
except ImportError:
raise ImportError(IMPORT_OPENSEARCH_PY_ERROR)
return OpenSearch
def _import_async_opensearch() -> Any: def _get_opensearch_client(opensearch_url: str, **kwargs: Any) -> OpenSearch:
"""Import AsyncOpenSearch if available, otherwise raise error."""
try:
from opensearchpy import AsyncOpenSearch
except ImportError:
raise ImportError(IMPORT_ASYNC_OPENSEARCH_PY_ERROR)
return AsyncOpenSearch
def _import_bulk() -> Any:
"""Import bulk if available, otherwise raise error."""
try:
from opensearchpy.helpers import bulk
except ImportError:
raise ImportError(IMPORT_OPENSEARCH_PY_ERROR)
return bulk
def _import_async_bulk() -> Any:
"""Import async_bulk if available, otherwise raise error."""
try:
from opensearchpy.helpers import async_bulk
except ImportError:
raise ImportError(IMPORT_ASYNC_OPENSEARCH_PY_ERROR)
return async_bulk
def _import_not_found_error() -> Any:
"""Import not found error if available, otherwise raise error."""
try:
from opensearchpy.exceptions import NotFoundError
except ImportError:
raise ImportError(IMPORT_OPENSEARCH_PY_ERROR)
return NotFoundError
def _get_opensearch_client(opensearch_url: str, **kwargs: Any) -> Any:
"""Get OpenSearch client from the opensearch_url, otherwise raise error.""" """Get OpenSearch client from the opensearch_url, otherwise raise error."""
try: try:
opensearch = _import_opensearch() from opensearchpy import OpenSearch
client = opensearch(opensearch_url, **kwargs)
client = OpenSearch(opensearch_url, **kwargs)
except ImportError:
raise ImportError(IMPORT_OPENSEARCH_PY_ERROR)
except ValueError as e: except ValueError as e:
raise ImportError( raise ImportError(
f"OpenSearch client string provided is not in proper format. " f"OpenSearch client string provided is not in proper format. "
@ -82,11 +43,14 @@ def _get_opensearch_client(opensearch_url: str, **kwargs: Any) -> Any:
return client return client
def _get_async_opensearch_client(opensearch_url: str, **kwargs: Any) -> Any: def _get_async_opensearch_client(opensearch_url: str, **kwargs: Any) -> AsyncOpenSearch:
"""Get AsyncOpenSearch client from the opensearch_url, otherwise raise error.""" """Get AsyncOpenSearch client from the opensearch_url, otherwise raise error."""
try: try:
async_opensearch = _import_async_opensearch() from opensearchpy import AsyncOpenSearch
client = async_opensearch(opensearch_url, **kwargs)
client = AsyncOpenSearch(opensearch_url, **kwargs)
except ImportError:
raise ImportError(IMPORT_ASYNC_OPENSEARCH_PY_ERROR)
except ValueError as e: except ValueError as e:
raise ImportError( raise ImportError(
f"AsyncOpenSearch client string provided is not in proper format. " f"AsyncOpenSearch client string provided is not in proper format. "
@ -127,7 +91,7 @@ def _is_aoss_enabled(http_auth: Any) -> bool:
def _bulk_ingest_embeddings( def _bulk_ingest_embeddings(
client: Any, client: OpenSearch,
index_name: str, index_name: str,
embeddings: List[List[float]], embeddings: List[List[float]],
texts: Iterable[str], texts: Iterable[str],
@ -142,16 +106,19 @@ def _bulk_ingest_embeddings(
"""Bulk Ingest Embeddings into given index.""" """Bulk Ingest Embeddings into given index."""
if not mapping: if not mapping:
mapping = dict() mapping = dict()
try:
from opensearchpy.exceptions import NotFoundError
from opensearchpy.helpers import bulk
except ImportError:
raise ImportError(IMPORT_OPENSEARCH_PY_ERROR)
bulk = _import_bulk()
not_found_error = _import_not_found_error()
requests = [] requests = []
return_ids = [] return_ids = []
mapping = mapping mapping = mapping
try: try:
client.indices.get(index=index_name) client.indices.get(index=index_name)
except not_found_error: except NotFoundError:
client.indices.create(index=index_name, body=mapping) client.indices.create(index=index_name, body=mapping)
for i, text in enumerate(texts): for i, text in enumerate(texts):
@ -177,7 +144,7 @@ def _bulk_ingest_embeddings(
async def _abulk_ingest_embeddings( async def _abulk_ingest_embeddings(
client: Any, client: AsyncOpenSearch,
index_name: str, index_name: str,
embeddings: List[List[float]], embeddings: List[List[float]],
texts: Iterable[str], texts: Iterable[str],
@ -193,14 +160,18 @@ async def _abulk_ingest_embeddings(
if not mapping: if not mapping:
mapping = dict() mapping = dict()
async_bulk = _import_async_bulk() try:
not_found_error = _import_not_found_error() from opensearchpy.exceptions import NotFoundError
from opensearchpy.helpers import async_bulk
except ImportError:
raise ImportError(IMPORT_ASYNC_OPENSEARCH_PY_ERROR)
requests = [] requests = []
return_ids = [] return_ids = []
try: try:
await client.indices.get(index=index_name) await client.indices.get(index=index_name)
except not_found_error: except NotFoundError:
await client.indices.create(index=index_name, body=mapping) await client.indices.create(index=index_name, body=mapping)
for i, text in enumerate(texts): for i, text in enumerate(texts):
@ -230,7 +201,7 @@ async def _abulk_ingest_embeddings(
def _default_scripting_text_mapping( def _default_scripting_text_mapping(
dim: int, dim: int,
vector_field: str = "vector_field", vector_field: str = "vector_field",
) -> Dict: ) -> Dict[str, Any]:
"""For Painless Scripting or Script Scoring,the default mapping to create index.""" """For Painless Scripting or Script Scoring,the default mapping to create index."""
return { return {
"mappings": { "mappings": {
@ -249,7 +220,7 @@ def _default_text_mapping(
ef_construction: int = 512, ef_construction: int = 512,
m: int = 16, m: int = 16,
vector_field: str = "vector_field", vector_field: str = "vector_field",
) -> Dict: ) -> Dict[str, Any]:
"""For Approximate k-NN Search, this is the default mapping to create index.""" """For Approximate k-NN Search, this is the default mapping to create index."""
return { return {
"settings": {"index": {"knn": True, "knn.algo_param.ef_search": ef_search}}, "settings": {"index": {"knn": True, "knn.algo_param.ef_search": ef_search}},
@ -275,7 +246,7 @@ def _default_approximate_search_query(
k: int = 4, k: int = 4,
vector_field: str = "vector_field", vector_field: str = "vector_field",
score_threshold: Optional[float] = 0.0, score_threshold: Optional[float] = 0.0,
) -> Dict: ) -> Dict[str, Any]:
"""For Approximate k-NN Search, this is the default query.""" """For Approximate k-NN Search, this is the default query."""
return { return {
"size": k, "size": k,
@ -291,7 +262,7 @@ def _approximate_search_query_with_boolean_filter(
vector_field: str = "vector_field", vector_field: str = "vector_field",
subquery_clause: str = "must", subquery_clause: str = "must",
score_threshold: Optional[float] = 0.0, score_threshold: Optional[float] = 0.0,
) -> Dict: ) -> Dict[str, Any]:
"""For Approximate k-NN Search, with Boolean Filter.""" """For Approximate k-NN Search, with Boolean Filter."""
return { return {
"size": k, "size": k,
@ -313,7 +284,7 @@ def _approximate_search_query_with_efficient_filter(
k: int = 4, k: int = 4,
vector_field: str = "vector_field", vector_field: str = "vector_field",
score_threshold: Optional[float] = 0.0, score_threshold: Optional[float] = 0.0,
) -> Dict: ) -> Dict[str, Any]:
"""For Approximate k-NN Search, with Efficient Filter for Lucene and """For Approximate k-NN Search, with Efficient Filter for Lucene and
Faiss Engines.""" Faiss Engines."""
search_query = _default_approximate_search_query( search_query = _default_approximate_search_query(
@ -330,7 +301,7 @@ def _default_script_query(
pre_filter: Optional[Dict] = None, pre_filter: Optional[Dict] = None,
vector_field: str = "vector_field", vector_field: str = "vector_field",
score_threshold: Optional[float] = 0.0, score_threshold: Optional[float] = 0.0,
) -> Dict: ) -> Dict[str, Any]:
"""For Script Scoring Search, this is the default query.""" """For Script Scoring Search, this is the default query."""
if not pre_filter: if not pre_filter:
@ -376,7 +347,7 @@ def _default_painless_scripting_query(
pre_filter: Optional[Dict] = None, pre_filter: Optional[Dict] = None,
vector_field: str = "vector_field", vector_field: str = "vector_field",
score_threshold: Optional[float] = 0.0, score_threshold: Optional[float] = 0.0,
) -> Dict: ) -> Dict[str, Any]:
"""For Painless Scripting Search, this is the default query.""" """For Painless Scripting Search, this is the default query."""
if not pre_filter: if not pre_filter:
@ -692,7 +663,10 @@ class OpenSearchVectorSearch(VectorStore):
refresh_indices: Whether to refresh the index refresh_indices: Whether to refresh the index
after deleting documents. Defaults to True. after deleting documents. Defaults to True.
""" """
bulk = _import_bulk() try:
from opensearchpy.helpers import bulk
except ImportError:
raise ImportError(IMPORT_OPENSEARCH_PY_ERROR)
body = [] body = []